-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[MLIR] Extend the extractvalue fold method #172297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Vadim Curcă (VadimCurca) ChangesExtend the Full diff: https://github.com/llvm/llvm-project/pull/172297.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b819485b1be4..3dcc3f1373a5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1898,6 +1898,23 @@ static Type getInsertExtractValueElementType(Type llvmType,
return llvmType;
}
+/// Extract the element at `index` from `attr` if it is an `ElementsAttr` or
+/// `ArrayAttr`. Returns `nullptr` if `attr` is not one of those types or if the
+/// `index` is out of bounds.
+static Attribute extractElementAt(Attribute attr, size_t index) {
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
+ if (index < static_cast<size_t>(elementsAttr.getNumElements()))
+ return elementsAttr.getValues<Attribute>()[index];
+ return nullptr;
+ }
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ if (index < arrayAttr.getValue().size())
+ return arrayAttr[index];
+ return nullptr;
+ }
+ return nullptr;
+}
+
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
@@ -1907,22 +1924,11 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
return getResult();
}
- {
- DenseElementsAttr constval;
- matchPattern(getContainer(), m_Constant(&constval));
- if (constval && constval.getElementType() == getType()) {
- if (isa<SplatElementsAttr>(constval))
- return constval.getSplatValue<Attribute>();
- if (getPosition().size() == 1)
- return constval.getValues<Attribute>()[getPosition()[0]];
- }
- }
-
- auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
+ Operation *container = getContainer().getDefiningOp();
OpFoldResult result = {};
ArrayRef<int64_t> extractPos = getPosition();
bool switchedToInsertedValue = false;
- while (insertValueOp) {
+ while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
auto extractPosSize = extractPos.size();
auto insertPosSize = insertPos.size();
@@ -1945,7 +1951,7 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
// In the above example, %4 is folded to %arg1.
if (extractPosSize > insertPosSize &&
extractPos.take_front(insertPosSize) == insertPos) {
- insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+ container = insertValueOp.getValue().getDefiningOp();
extractPos = extractPos.drop_front(insertPosSize);
switchedToInsertedValue = true;
continue;
@@ -1975,9 +1981,32 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
getContainerMutable().assign(insertValueOp.getContainer());
result = getResult();
}
- insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
+ container = insertValueOp.getContainer().getDefiningOp();
+ }
+ if (!container)
+ return result;
+
+ Attribute containerAttr;
+ if (!matchPattern(container, m_Constant(&containerAttr)))
+ return nullptr;
+ for (int64_t pos : extractPos) {
+ Attribute attrElement = extractElementAt(containerAttr, pos);
+
+ // It is possible to fail to extract an element from the container and still
+ // fold the operation to a constant. For example:
+ // ```
+ // %container = llvm.mlir.zero : !llvm.struct<(i8, i32)>
+ // %result = llvm.extractvalue %container[0] : !llvm.struct<(i8, i32)>
+ // ```
+ // In this case, `containerAttr` is an `LLVM::ZeroAttr` that does not
+ // contain any nested elements, yet the operation can be folded to a zero
+ // constant.
+ if (!attrElement)
+ return containerAttr;
+
+ containerAttr = attrElement;
}
- return result;
+ return containerAttr;
}
LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 755e3a3a5fa09..8303afc9eb033 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -112,10 +112,10 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
// -----
-// CHECK-LABEL: fold_extract_const
+// CHECK-LABEL: fold_extract_const_array
// CHECK-NOT: extractvalue
// CHECK: llvm.mlir.constant(5.000000e-01 : f64)
-llvm.func @fold_extract_const() -> f64 {
+llvm.func @fold_extract_const_array() -> f64 {
%a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
llvm.return %b : f64
@@ -123,6 +123,17 @@ llvm.func @fold_extract_const() -> f64 {
// -----
+// CHECK-LABEL: fold_extract_const_struct
+llvm.func @fold_extract_const_struct() -> i32 {
+ // CHECK-NOT: extractvalue
+ // CHECK: llvm.mlir.constant(2 : i32)
+ %a = llvm.mlir.constant([1 : i16, 2 : i32]) : !llvm.struct<(i16, i32)>
+ %b = llvm.extractvalue %a[1] : !llvm.struct<(i16, i32)>
+ llvm.return %b : i32
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_splat
// CHECK-NOT: extractvalue
// CHECK: llvm.mlir.constant(-8.900000e+01 : f64)
@@ -134,6 +145,90 @@ llvm.func @fold_extract_splat() -> f64 {
// -----
+// CHECK-LABEL: fold_extract_splat_nested
+llvm.func @fold_extract_splat_nested() -> i32 {
+ // CHECK-NOT: extractvalue
+ // CHECK: llvm.mlir.constant(1 : i32)
+ %a = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ %b = llvm.extractvalue %a[1, 1] : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ llvm.return %b : i32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_sparse
+llvm.func @fold_extract_sparse() -> f32 {
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+ // CHECK-DAG: %[[C42:.*]] = llvm.mlir.constant(4.200000e+01 : f32)
+ %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : tensor<4xf32>) : !llvm.array<4 x f32>
+ %1 = llvm.extractvalue %0[0] : !llvm.array<4 x f32>
+ %2 = llvm.extractvalue %0[1] : !llvm.array<4 x f32>
+ // CHECK: llvm.fadd %[[C42]], %[[C0]]
+ %3 = llvm.fadd %1, %2 : f32
+ llvm.return %3 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_zero
+llvm.func @fold_zero() -> i32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i32
+ %0 = llvm.mlir.zero : !llvm.struct<(i16, i32)>
+
+ %1 = llvm.mlir.undef : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ %2 = llvm.insertvalue %0, %1[0] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ %3 = llvm.extractvalue %2[0, 1] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+ // CHECK: llvm.return %[[ZERO]]
+ llvm.return %3 : i32
+}
+
+// -----
+
+llvm.func @use_struct(!llvm.struct<(i16, i32)>)
+
+// CHECK-LABEL: fold_undef
+llvm.func @fold_undef() -> i32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[UNDEF_I32:.*]] = llvm.mlir.undef : i32
+ // CHECK-DAG: %[[UNDEF_STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i32)>
+ %0 = llvm.mlir.undef : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+
+ %1 = llvm.extractvalue %0[1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+ // CHECK: llvm.call @use_struct(%[[UNDEF_STRUCT]])
+ llvm.call @use_struct(%1) : (!llvm.struct<(i16, i32)>) -> ()
+
+ %2 = llvm.extractvalue %0[1, 1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+ // CHECK: llvm.return %[[UNDEF_I32]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+llvm.func @use_array(!llvm.array<8 x f32>)
+
+// CHECK-LABEL: fold_poison
+llvm.func @fold_poison() -> f32 {
+ // CHECK-NOT: insertvalue
+ // CHECK-NOT: extractvalue
+ // CHECK-DAG: %[[POISON_F32:.*]] = llvm.mlir.poison : f32
+ // CHECK-DAG: %[[POISON_ARRAY:.*]] = llvm.mlir.poison : !llvm.array<8 x f32>
+ %0 = llvm.mlir.poison : !llvm.array<2 x !llvm.array<8 x f32>>
+
+ %1 = llvm.extractvalue %0[1] : !llvm.array<2 x !llvm.array<8 x f32>>
+ // CHECK: llvm.call @use_array(%[[POISON_ARRAY]])
+ llvm.call @use_array(%1) : (!llvm.array<8 x f32>) -> ()
+
+ %2 = llvm.extractvalue %0[1, 1] : !llvm.array<2 x !llvm.array<8 x f32>>
+ // CHECK: llvm.return %[[POISON_F32]]
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: fold_bitcast
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
// CHECK-NEXT: llvm.return %[[ARG]]
|
Dinistro
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember if we support any kind of nested struct constants, but might be nice to check if this is actually foldable as well (assuming they exit).
Otherwise, LGTM! Thanks for the fix.
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for extending this!
Extend the `extractvalue` fold method to support extracting from constant containers, such as `llvm.mlir.zero`, `llvm.mlir.undef`, `llvm.mlir.poison`, and `llvm.mlir.constant` holding `ElementsAttr` or `ArrayAttr`.
| return result; | ||
|
|
||
| Attribute containerAttr; | ||
| if (!matchPattern(container, m_Constant(&containerAttr))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not following the logic very well: seems to me that we're at the end of a chain of insertValueOps and you're then just looking at the initial constant. However the initial value can have been overridden by one of the insert, can't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be misunderstanding your comment, but note that container here is not strictly the initial constant. It is the container at the end of the insertValueOp chain (if such a chain exists). If there isn't a chain, then container corresponds to the initial constant.
63054e5 to
f7bb67e
Compare
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change.
LGTM from my end.
Extend the
extractvaluefold method to support extracting from constant containers, such asllvm.mlir.zero,llvm.mlir.undef,llvm.mlir.poison, andllvm.mlir.constantholdingElementsAttrorArrayAttr.