Skip to content

Commit ab4c945

Browse files
committed
[mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator
This commit adds support for an MXFP CONV2D operation, CONV2D_BLOCK_SCALED, added to the specification in arm/tosa-specification@408a5e5. This includes: - Operator definition - Addition of the EXT_MXFP_CONV extension - Verification logic for the operator - Output shape inference for the operator - Validation checks to ensure compliance with the TOSA specification. Change-Id: I7553f7796d2d156f43310108e9a69a593cdece33
1 parent 72f3995 commit ab4c945

17 files changed

+765
-91
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,18 @@ extensionComplianceMap = {
512512
{{Extension::bf16},
513513
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
514514
SpecificationVersion::V_1_0}}}}},
515+
{"tosa.conv2d_block_scaled",
516+
{{{Extension::mxfp_conv},
517+
{{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T},
518+
SpecificationVersion::V_1_1_DRAFT},
519+
{{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T},
520+
SpecificationVersion::V_1_1_DRAFT},
521+
{{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T},
522+
SpecificationVersion::V_1_1_DRAFT},
523+
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T},
524+
SpecificationVersion::V_1_1_DRAFT},
525+
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T},
526+
SpecificationVersion::V_1_1_DRAFT}}}}},
515527
{"tosa.conv3d",
516528
{{{Extension::int4},
517529
{{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
241241
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
242242
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
243243
// MXFP : Microscaling formats.
244+
// MXFP_CONV : Microscaling format convolution.
244245
//===----------------------------------------------------------------------===//
245246

246247
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,23 +275,24 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
274275
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
275276
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
276277
def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>;
278+
def Tosa_EXT_MXFP_CONV : I32EnumAttrCase<"mxfp_conv", 14>;
277279

278280

279281
def Tosa_ExtensionAttr
280282
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
281283
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
282284
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
283285
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
284-
Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
286+
Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
285287
]> {
286288
let extraClassDeclaration = [{
287-
static llvm::SmallVector<Extension, 13> getAllValues() {
289+
static llvm::SmallVector<Extension, 14> getAllValues() {
288290
return {
289291
Extension::int16, Extension::int4, Extension::bf16,
290292
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
291293
Extension::variable, Extension::controlflow, Extension::doubleround,
292294
Extension::inexactround, Extension::dynamic, Extension::mxfp,
293-
Extension::int64
295+
Extension::int64, Extension::mxfp_conv
294296
};
295297
}
296298
}];

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
163163
let hasVerifier = 1;
164164
}
165165

166+
//===----------------------------------------------------------------------===//
167+
// Operator: conv2d_block_scaled
168+
//===----------------------------------------------------------------------===//
169+
def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
170+
let summary = "Performs two dimensional convolution using block scaled tensors.";
171+
172+
let description = [{
173+
Performs a 2D convolution over the given input data and scales, using
174+
the weight data and scales. Implementations may choose to skip calculation
175+
of multiplies in the padding area.
176+
}];
177+
178+
let arguments = (ins
179+
Tosa_MXFPDataTensor4D:$input_data,
180+
Tosa_MXFPScaleTensor4D:$input_scale,
181+
Tosa_MXFPDataTensor4D:$weight_data,
182+
Tosa_MXFPScaleTensor4D:$weight_scale,
183+
Tosa_Tensor1D:$bias,
184+
Rank4TosaShape:$pad,
185+
Rank2TosaShape:$stride,
186+
Rank2TosaShape:$dilation,
187+
Tosa_BlockSizeAttr:$block_size
188+
);
189+
190+
let results = (outs
191+
Tosa_Tensor4D:$output
192+
);
193+
194+
list<Availability> availability = [
195+
Profile<[Tosa_PRO_FP]>,
196+
Extension<[Tosa_EXT_MXFP_CONV]>,
197+
];
198+
199+
let hasVerifier = 1;
200+
let hasCustomAssemblyFormat = 1;
201+
}
202+
166203
//===----------------------------------------------------------------------===//
167204
// Operator: conv3d
168205
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
149149
case Extension::fp8e5m2:
150150
case Extension::fft:
151151
case Extension::mxfp:
152+
case Extension::mxfp_conv:
152153
return {Profile::pro_fp};
153154
case Extension::variable:
154155
case Extension::controlflow:

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
202202
def Tosa_TensorUpto4D : AnyTypeOf<[
203203
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
204204

205+
def Tosa_IndexTensor1D : AnyTypeOf<[
206+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
205207
def Tosa_IndexTensor2D : AnyTypeOf<[
206208
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
207209

@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
216218
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
217219
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
218220
]>;
221+
def Tosa_MXFPDataTensor4D : AnyTypeOf<[
222+
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
223+
TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
224+
]>;
225+
def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
226+
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
227+
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
228+
]>;
219229
def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
220230
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
221231
TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],

mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
4343
return TosaSpecificationVersion(1, 0);
4444
case Extension::mxfp:
4545
case Extension::int64:
46+
case Extension::mxfp_conv:
4647
return TosaSpecificationVersion(1, 1);
4748
case Extension::none:
4849
return TosaSpecificationVersion(0, 0);

0 commit comments

Comments
 (0)