Skip to content

Conversation

@lhutton1
Copy link
Contributor

This commit adds support for an MXFP CONV2D operation, CONV2D_BLOCK_SCALED, added to the specification in arm/tosa-specification@408a5e5.

This includes:

  • Operator definition
  • Addition of the EXT_MXFP_CONV extension
  • Verification logic for the operator
  • Output shape inference for the operator
  • Validation checks to ensure compliance with the TOSA specification.

@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit adds support for an MXFP CONV2D operation, CONV2D_BLOCK_SCALED, added to the specification in arm/tosa-specification@408a5e5.

This includes:

  • Operator definition
  • Addition of the EXT_MXFP_CONV extension
  • Verification logic for the operator
  • Output shape inference for the operator
  • Validation checks to ensure compliance with the TOSA specification.

Patch is 70.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172294.diff

17 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+7)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+5-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+10)
  • (modified) mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp (+1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+309-79)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+13)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+52-2)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+12-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+13-2)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+71)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+20)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+13-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+48)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+12-1)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+132)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..e452723f193b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -512,6 +512,13 @@ extensionComplianceMap = {
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
+    {"tosa.conv2d_block_scaled",
+     {{{Extension::mxfp_conv},
+       {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv3d",
      {{{Extension::int4},
        {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..421abc939b2e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 // DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 // MXFP         : Microscaling formats.
+// MXFP_CONV    : Microscaling format convolution.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64        : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_MXFP_CONV    : I32EnumAttrCase<"mxfp_conv", 14>;
 
 
 def Tosa_ExtensionAttr
@@ -281,16 +283,16 @@ def Tosa_ExtensionAttr
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
     ]> {
   let extraClassDeclaration = [{
-    static llvm::SmallVector<Extension, 13> getAllValues() {
+    static llvm::SmallVector<Extension, 14> getAllValues() {
       return {
         Extension::int16, Extension::int4, Extension::bf16,
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
         Extension::inexactround, Extension::dynamic, Extension::mxfp,
-        Extension::int64
+        Extension::int64, Extension::mxfp_conv
       };
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+  let summary = "Performs two dimensional convolution using block scaled tensors.";
+
+  let description = [{
+    Performs a 2D convolution over the given input data and scales, using
+    the weight data and scales. Implementations may choose to skip calculation
+    of multiplies in the padding area.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor4D:$input_data,
+    Tosa_MXFPScaleTensor4D:$input_scale,
+    Tosa_MXFPDataTensor4D:$weight_data,
+    Tosa_MXFPScaleTensor4D:$weight_scale,
+    Tosa_Tensor1D:$bias,
+    Rank4TosaShape:$pad,
+    Rank2TosaShape:$stride,
+    Rank2TosaShape:$dilation,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP_CONV]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..5c77bd701e416 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
     case Extension::fp8e5m2:
     case Extension::fft:
     case Extension::mxfp:
+    case Extension::mxfp_conv:
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..0468ca29e10ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
+def Tosa_IndexTensor1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
 def Tosa_IndexTensor2D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
 
@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
   TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
 ]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
 def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
   TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..2e0a0d85d7dbe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::mxfp_conv:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
     return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..6382c28ed4312 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
   return type.getIntOrFloatBitWidth();
 }
 
+// Update dim size if current dim is dynamic, otherwise raise an error if sizes
+// do not match
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+                                    const int64_t newDim,
+                                    const StringRef operandName,
+                                    const StringRef dimName) {
+  if (ShapedType::isDynamic(currDim)) {
+    currDim = newDim;
+    return success();
+  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+    return op->emitOpError("expected ")
+           << dimName << " of " << operandName << " to match size " << currDim
+           << ", got " << newDim;
+  }
+  return success();
+}
+
+LogicalResult verifyConvOutputSize(
+    Operation *op, const int64_t inputSize, const int64_t kernelSize,
+    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+    const llvm::StringRef padAfterName) {
+  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
+    return success();
+
+  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+      stride);
+  if (!calculatedOutSizeMinusOne.has_value())
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+           << ") / " << stride;
+
+  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+    return op->emitOpError("calculated output ")
+           << dimName << " did not match expected: "
+           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -791,53 +849,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
       llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [&op](const int64_t inputSize, const int64_t kernelSize,
-              const int64_t outputSize, const int64_t padBefore,
-              const int64_t padAfter, const int64_t stride,
-              const int64_t dilation, const llvm::StringRef dimName,
-              const llvm::StringRef dimAxis,
-              const llvm::StringRef padBeforeName,
-              const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return op.emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return op.emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
     // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -845,14 +866,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(0),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(0),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(1),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -860,20 +881,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "depth", "d", "front", "back")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(3), weightType.getDimSize(3),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(3), weightType.getDimSize(3),
               outputType.getDimSize(3), padding[4], padding[5], strides[2],
               dilations[2], "width", "x", "left", "right")))
         return failure();
@@ -1954,20 +1975,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
                                     "B_data")))
     return failure();
 
-  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
-                                   const StringRef operandName,
-                                   const StringRef dimName) -> LogicalResult {
-    if (ShapedType::isDynamic(currDim)) {
-      currDim = newDim;
-      return success();
-    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
-      return emitOpError("expected ")
-             << dimName << " of " << operandName << " to match size " << currDim
-             << ", got " << newDim;
-    }
-    return success();
-  };
-
   // Verify input shape compatibility
   int64_t N = ShapedType::kDynamic;
   int64_t D = ShapedType::kDynamic;
@@ -1985,32 +1992,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
   if (aScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
-                                     "height")))
+    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+                                     "a_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+                                     "a_scale", "height")))
       return failure();
     multiplesOfC = aScaleShape.getDimSize(2);
   }
 
   const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
   if (bDataShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
-                                     "channels")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+                                     "b_data", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+                                     "b_data", "channels")))
       return failure();
     W = bDataShape.getDimSize(1);
   }
 
   const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
   if (bScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
-                                     "width")) ||
-        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
-                                     "b_scale", "C/block_size")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+                                     "b_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+                                     "b_scale", "width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+                                     bScaleShape.getDimSize(2), "b_scale",
+                                     "C/block_size")))
       return failure();
   }
 
@@ -3485,6 +3493,228 @@ LogicalResult Conv2DOp::verify() {
   return success();
 }
 
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+  int64_t inputWidth = ShapedType::kDynamic;
+  int64_t inputHeight = ShapedType::kDynamic;
+  int64_t weightWidth = ShapedType::kDynamic;
+  int64_t weightHeight = ShapedType::kDynamic;
+
+  // Input shape describes input width/height and batch.
+  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    outShape[0] = inputDataShape.getDimSize(0);
+    inputHeight = inputDataShape.getDimSize(1);
+    inputWidth = inputDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0])
+                      ? inputScaleShape.getDimSize(0)
+                      : outShape[0];
+    inputHeight = ShapedType::isDynamic(inputHeight)
+                      ? ...
[truncated]

@github-actions
Copy link

github-actions bot commented Dec 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

This commit adds support for an MXFP CONV2D operation,
CONV2D_BLOCK_SCALED, added to the specification in
arm/tosa-specification@408a5e5.

This includes:
- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification.

Change-Id: I7553f7796d2d156f43310108e9a69a593cdece33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants