
Conversation

@akroviakov
Contributor

This PR enables distribution of vector::ConstantMaskOp (whose mask sizes are attributes) by extending the existing distribution of its SSA-operand sibling, vector::CreateMaskOp. Both ops have equivalent semantics, so we can materialize the attribute values as SSA values and plug them into the existing distribution logic.
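
For illustration (a sketch, not taken from the PR; the %attr and %ssa names are made up), the equivalence this relies on:

  %c2 = arith.constant 2 : index
  %attr = vector.constant_mask [2] : vector<4xi1>  // mask sizes carried as an attribute
  %ssa = vector.create_mask %c2 : vector<4xi1>     // equivalent SSA-operand form

Both produce a vector<4xi1> whose first two elements are set.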

@llvmbot
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR enables distribution of vector::ConstantMaskOp (whose mask sizes are attributes) by extending the existing distribution of its SSA-operand sibling, vector::CreateMaskOp. Both ops have equivalent semantics, so we can materialize the attribute values as SSA values and plug them into the existing distribution logic.


Full diff: https://github.com/llvm/llvm-project/pull/172268.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+22-6)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+33)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b5e950733a22..90d6901089525 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1123,26 +1123,42 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
+        getWarpResult(warpOp, (llvm::IsaPred<vector::CreateMaskOp>));
+    if (!yieldOperand)
+      yieldOperand =
+          getWarpResult(warpOp, (llvm::IsaPred<vector::ConstantMaskOp>));
     if (!yieldOperand)
       return failure();
 
-    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Operation *mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    if (!mask)
+      mask = yieldOperand->get().getDefiningOp<vector::ConstantMaskOp>();
 
     // Early exit if any values needed for calculating the new mask indices
     // are defined inside the warp op.
-    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+    if (mask->getOperands().size() &&
+        !llvm::all_of(mask->getOperands(), [&](Value value) {
           return warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 
-    Location loc = mask.getLoc();
+    Location loc = mask->getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
     auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
-    VectorType seqType = mask.getVectorType();
+    VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
     ArrayRef<int64_t> seqShape = seqType.getShape();
     ArrayRef<int64_t> distShape = distType.getShape();
+    SmallVector<Value> materializedOperands;
+    if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(mask)) {
+      materializedOperands.append(createMaskOp.getOperands().begin(),
+                                  createMaskOp.getOperands().end());
+    } else if (auto constantMaskOp = dyn_cast<vector::ConstantMaskOp>(mask)) {
+      auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
+      for (auto dimSize : dimSizes)
+        materializedOperands.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+    }
 
     rewriter.setInsertionPointAfter(warpOp);
 
@@ -1170,7 +1186,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
       // mask sizes are always in the range [0, mask_vector_size[i]).
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
-          {delinearizedIds[i], mask.getOperand(i)});
+          {delinearizedIds[i], materializedOperands[i]});
       newOperands.push_back(maskDimIdx);
     }
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0cf6dd151e16c..135db02d543ef 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1779,6 +1779,21 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
 //       CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+// -----
+
+func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.constant_mask [1] : vector<32xi1>
+    gpu.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0] -> (-s0 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_constant_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index
+//       CHECK-PROP:   %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
 
@@ -1813,6 +1828,24 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
 //       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
 //       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+// -----
+
+func.func @warp_propagate_multi_dim_constant_mask(%laneid: index) -> vector<1x2x4xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.constant_mask [1, 1, 2]: vector<16x4x4xi1>
+    gpu.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0] -> (-(s0 floordiv 2) + 1)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0] -> (s0 * -2 + (s0 floordiv 2) * 4 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_constant_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index
+//       CHECK-PROP:   %[[CST2:.+]] = arith.constant 2 : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[LANEID]]]
+//       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[CST2]] : vector<1x2x4xi1>
 
 // -----
 

@llvmbot
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir-vector


/// %1 = vector.create_mask %ub : vector<1xi1>
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

Probably the function comments need to be tweaked a bit.

@charithaintc
Contributor

What is the need for merging the logic with CreateMask? Wouldn't a separate pattern be easier to understand and maintain?

@akroviakov
Contributor Author

The semantics of both ops are the same, and their interfaces are mostly the same (a list of mask sizes). Both ops are tightly coupled, hence their distributions are as well. It only costs a few lines of trivial code to fit ConstantMaskOp into the existing distribution logic. Does it really require a new (and mostly identical) pattern?

@charithaintc
Contributor

The semantics of both ops are the same, and their interfaces are mostly the same (a list of mask sizes). Both ops are tightly coupled, hence their distributions are as well. It only costs a few lines of trivial code to fit ConstantMaskOp into the existing distribution logic. Does it really require a new (and mostly identical) pattern?

In that case I suggest using the format used in WgToSg (i.e., some template naming):

using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;

@charithaintc charithaintc left a comment
Contributor


Looks good.

I suggest using a template to distinguish the two patterns:

WarpOpMaskOp<OpType>
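
A minimal sketch of what that templated pattern could look like (illustrative only, not the merged code; getMaskSizesAsValues is a hypothetical helper standing in for the operand/attribute materialization shown in the diff above):

  template <typename MaskOpTy>
  struct WarpOpMaskOp : public WarpDistributionPattern {
    using Base::Base;
    LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                  PatternRewriter &rewriter) const override {
      OpOperand *yieldOperand = getWarpResult(warpOp, llvm::IsaPred<MaskOpTy>);
      if (!yieldOperand)
        return failure();
      auto maskOp = yieldOperand->get().getDefiningOp<MaskOpTy>();
      // Hypothetical helper: returns CreateMaskOp's operands as-is, or
      // materializes arith.constant index values from ConstantMaskOp's
      // mask_dim_sizes attribute.
      SmallVector<Value> maskSizes = getMaskSizesAsValues(rewriter, maskOp);
      // ... shared distribution logic, as in the diff above ...
      return success();
    }
  };

  using WarpOpCreateMask = WarpOpMaskOp<vector::CreateMaskOp>;
  using WarpOpConstantMask = WarpOpMaskOp<vector::ConstantMaskOp>;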

@banach-space
Contributor

I haven't really worked on distribution, so I will leave the reviewing to the experts.

What is the need for merging the logic with CreateMask? Wouldn't a separate pattern be easier to understand and maintain?

Let me just mention that we are considering merging vector.create_mask and vector.constant_mask (unless we find good reasons to keep those two separate). Keeping the logic unified is highly desirable :)

@akroviakov akroviakov merged commit aba8ebb into llvm:main Dec 16, 2025
10 checks passed