Skip to content

Conversation

@VadimCurca
Copy link
Contributor

Extend the extractvalue fold method to support extracting from constant containers, such as llvm.mlir.zero, llvm.mlir.undef, llvm.mlir.poison, and llvm.mlir.constant holding ElementsAttr or ArrayAttr.

@VadimCurca VadimCurca marked this pull request as ready for review December 15, 2025 13:57
@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Vadim Curcă (VadimCurca)

Changes

Extend the extractvalue fold method to support extracting from constant containers, such as llvm.mlir.zero, llvm.mlir.undef, llvm.mlir.poison, and llvm.mlir.constant holding ElementsAttr or ArrayAttr.


Full diff: https://github.com/llvm/llvm-project/pull/172297.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+45-16)
  • (modified) mlir/test/Dialect/LLVMIR/canonicalize.mlir (+97-2)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5b819485b1be4..3dcc3f1373a5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1898,6 +1898,23 @@ static Type getInsertExtractValueElementType(Type llvmType,
   return llvmType;
 }
 
+/// Extract the element at `index` from `attr` if it is an `ElementsAttr` or
+/// `ArrayAttr`. Returns `nullptr` if `attr` is not one of those types or if the
+/// `index` is out of bounds.
+static Attribute extractElementAt(Attribute attr, size_t index) {
+  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
+    if (index < static_cast<size_t>(elementsAttr.getNumElements()))
+      return elementsAttr.getValues<Attribute>()[index];
+    return nullptr;
+  }
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    if (index < arrayAttr.getValue().size())
+      return arrayAttr[index];
+    return nullptr;
+  }
+  return nullptr;
+}
+
 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
   if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
     SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
@@ -1907,22 +1924,11 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     return getResult();
   }
 
-  {
-    DenseElementsAttr constval;
-    matchPattern(getContainer(), m_Constant(&constval));
-    if (constval && constval.getElementType() == getType()) {
-      if (isa<SplatElementsAttr>(constval))
-        return constval.getSplatValue<Attribute>();
-      if (getPosition().size() == 1)
-        return constval.getValues<Attribute>()[getPosition()[0]];
-    }
-  }
-
-  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
+  Operation *container = getContainer().getDefiningOp();
   OpFoldResult result = {};
   ArrayRef<int64_t> extractPos = getPosition();
   bool switchedToInsertedValue = false;
-  while (insertValueOp) {
+  while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
     ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
     auto extractPosSize = extractPos.size();
     auto insertPosSize = insertPos.size();
@@ -1945,7 +1951,7 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
     // In the above example, %4 is folded to %arg1.
     if (extractPosSize > insertPosSize &&
         extractPos.take_front(insertPosSize) == insertPos) {
-      insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
+      container = insertValueOp.getValue().getDefiningOp();
       extractPos = extractPos.drop_front(insertPosSize);
       switchedToInsertedValue = true;
       continue;
@@ -1975,9 +1981,32 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
       getContainerMutable().assign(insertValueOp.getContainer());
       result = getResult();
     }
-    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
+    container = insertValueOp.getContainer().getDefiningOp();
+  }
+  if (!container)
+    return result;
+
+  Attribute containerAttr;
+  if (!matchPattern(container, m_Constant(&containerAttr)))
+    return nullptr;
+  for (int64_t pos : extractPos) {
+    Attribute attrElement = extractElementAt(containerAttr, pos);
+
+    // It is possible to fail to extract an element from the container and still
+    // fold the operation to a constant. For example:
+    // ```
+    // %container = llvm.mlir.zero : !llvm.struct<(i8, i32)>
+    // %result = llvm.extractvalue %container[0] : !llvm.struct<(i8, i32)>
+    // ```
+    // In this case, `containerAttr` is an `LLVM::ZeroAttr` that does not
+    // contain any nested elements, yet the operation can be folded to a zero
+    // constant.
+    if (!attrElement)
+      return containerAttr;
+
+    containerAttr = attrElement;
   }
-  return result;
+  return containerAttr;
 }
 
 LogicalResult ExtractValueOp::verify() {
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 755e3a3a5fa09..8303afc9eb033 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -112,10 +112,10 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
 
 // -----
 
-// CHECK-LABEL: fold_extract_const
+// CHECK-LABEL: fold_extract_const_array
 // CHECK-NOT: extractvalue
 // CHECK:  llvm.mlir.constant(5.000000e-01 : f64)
-llvm.func @fold_extract_const() -> f64 {
+llvm.func @fold_extract_const_array() -> f64 {
   %a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
   %b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
   llvm.return %b : f64
@@ -123,6 +123,17 @@ llvm.func @fold_extract_const() -> f64 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_const_struct
+llvm.func @fold_extract_const_struct() -> i32 {
+  // CHECK-NOT: extractvalue
+  // CHECK: llvm.mlir.constant(2 : i32)
+  %a = llvm.mlir.constant([1 : i16, 2 : i32]) : !llvm.struct<(i16, i32)>
+  %b = llvm.extractvalue %a[1] : !llvm.struct<(i16, i32)>
+  llvm.return %b : i32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_splat
 // CHECK-NOT: extractvalue
 // CHECK: llvm.mlir.constant(-8.900000e+01 : f64)
@@ -134,6 +145,90 @@ llvm.func @fold_extract_splat() -> f64 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_splat_nested
+llvm.func @fold_extract_splat_nested() -> i32 {
+  // CHECK-NOT: extractvalue
+  // CHECK: llvm.mlir.constant(1 : i32)
+  %a = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  %b = llvm.extractvalue %a[1, 1] : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  llvm.return %b : i32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_sparse
+llvm.func @fold_extract_sparse() -> f32 {
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0.000000e+00 : f32)
+  // CHECK-DAG: %[[C42:.*]] = llvm.mlir.constant(4.200000e+01 : f32)
+  %0 = llvm.mlir.constant(sparse<[0], [4.2e+01]> : tensor<4xf32>) : !llvm.array<4 x f32>
+  %1 = llvm.extractvalue %0[0] : !llvm.array<4 x f32>
+  %2 = llvm.extractvalue %0[1] : !llvm.array<4 x f32>
+  // CHECK: llvm.fadd %[[C42]], %[[C0]]
+  %3 = llvm.fadd %1, %2 : f32
+  llvm.return %3 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_zero
+llvm.func @fold_zero() -> i32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i32
+  %0 = llvm.mlir.zero : !llvm.struct<(i16, i32)>
+
+  %1 = llvm.mlir.undef : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  %2 = llvm.insertvalue %0, %1[0] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  %3 = llvm.extractvalue %2[0, 1] : !llvm.array<2 x !llvm.struct<(i16, i32)>>
+  // CHECK: llvm.return %[[ZERO]]
+  llvm.return %3 : i32
+}
+
+// -----
+
+llvm.func @use_struct(!llvm.struct<(i16, i32)>)
+
+// CHECK-LABEL: fold_undef
+llvm.func @fold_undef() -> i32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[UNDEF_I32:.*]] = llvm.mlir.undef : i32
+  // CHECK-DAG: %[[UNDEF_STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i32)>
+  %0 = llvm.mlir.undef : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+
+  %1 = llvm.extractvalue %0[1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+  // CHECK: llvm.call @use_struct(%[[UNDEF_STRUCT]])
+  llvm.call @use_struct(%1) : (!llvm.struct<(i16, i32)>) -> ()
+
+  %2 = llvm.extractvalue %0[1, 1] : !llvm.struct<(i8, !llvm.struct<(i16, i32)>)>
+  // CHECK: llvm.return %[[UNDEF_I32]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+llvm.func @use_array(!llvm.array<8 x f32>)
+
+// CHECK-LABEL: fold_poison
+llvm.func @fold_poison() -> f32 {
+  // CHECK-NOT: insertvalue
+  // CHECK-NOT: extractvalue
+  // CHECK-DAG: %[[POISON_F32:.*]] = llvm.mlir.poison : f32
+  // CHECK-DAG: %[[POISON_ARRAY:.*]] = llvm.mlir.poison : !llvm.array<8 x f32>
+  %0 = llvm.mlir.poison : !llvm.array<2 x !llvm.array<8 x f32>>
+
+  %1 = llvm.extractvalue %0[1] : !llvm.array<2 x !llvm.array<8 x f32>>
+  // CHECK: llvm.call @use_array(%[[POISON_ARRAY]])
+  llvm.call @use_array(%1) : (!llvm.array<8 x f32>) -> ()
+
+  %2 = llvm.extractvalue %0[1, 1] : !llvm.array<2 x !llvm.array<8 x f32>>
+  // CHECK: llvm.return %[[POISON_F32]]
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_bitcast
 // CHECK-SAME: %[[ARG:[[:alnum:]]+]]
 // CHECK-NEXT: llvm.return %[[ARG]]

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember if we support any kind of nested struct constants, but might be nice to check if this is actually foldable as well (assuming they exit).

Otherwise, LGTM! Thanks for the fix.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for extending this!

Extend the `extractvalue` fold method to support extracting from constant
containers, such as `llvm.mlir.zero`, `llvm.mlir.undef`,
`llvm.mlir.poison`, and `llvm.mlir.constant` holding `ElementsAttr` or
`ArrayAttr`.
return result;

Attribute containerAttr;
if (!matchPattern(container, m_Constant(&containerAttr)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not following the logic very well: seems to me that we're at the end of a chain of insertValueOps and you're then just looking at the initial constant. However the initial value can have been overridden by one of the insert, can't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be misunderstanding your comment, but note that container here is not strictly the initial constant. It is the container at the end of the insertValueOp chain (if such a chain exists). If there isn't a chain, then container corresponds to the initial constant.

@VadimCurca VadimCurca force-pushed the vadimc/extend_extractvalue_folder branch from 63054e5 to f7bb67e Compare December 15, 2025 15:27
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change.

LGTM from my end.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants