[MLIR][Vector] Add distribution pattern for vector::ConstantMaskOp
#172268
Conversation
@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

This PR enables `vector::ConstantMaskOp` (attribute-based indices) distribution by extending the distribution of its SSA-variant sibling `vector::CreateMaskOp`. Both ops offer equivalent semantics, so we can materialize the attributes as SSA values and plug into the existing distribution logic.

Full diff: https://github.com/llvm/llvm-project/pull/172268.diff

2 Files Affected:

- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b5e950733a22..90d6901089525 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1123,26 +1123,42 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
- getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
+ getWarpResult(warpOp, (llvm::IsaPred<vector::CreateMaskOp>));
+ if (!yieldOperand)
+ yieldOperand =
+ getWarpResult(warpOp, (llvm::IsaPred<vector::ConstantMaskOp>));
if (!yieldOperand)
return failure();
- auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+ Operation *mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+ if (!mask)
+ mask = yieldOperand->get().getDefiningOp<vector::ConstantMaskOp>();
// Early exit if any values needed for calculating the new mask indices
// are defined inside the warp op.
- if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+ if (mask->getOperands().size() &&
+ !llvm::all_of(mask->getOperands(), [&](Value value) {
return warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
- Location loc = mask.getLoc();
+ Location loc = mask->getLoc();
unsigned operandIndex = yieldOperand->getOperandNumber();
auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
- VectorType seqType = mask.getVectorType();
+ VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
ArrayRef<int64_t> seqShape = seqType.getShape();
ArrayRef<int64_t> distShape = distType.getShape();
+ SmallVector<Value> materializedOperands;
+ if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(mask)) {
+ materializedOperands.append(createMaskOp.getOperands().begin(),
+ createMaskOp.getOperands().end());
+ } else if (auto constantMaskOp = dyn_cast<vector::ConstantMaskOp>(mask)) {
+ auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
+ for (auto dimSize : dimSizes)
+ materializedOperands.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+ }
rewriter.setInsertionPointAfter(warpOp);
@@ -1170,7 +1186,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
// mask sizes are always in the range [0, mask_vector_size[i]).
Value maskDimIdx = affine::makeComposedAffineApply(
rewriter, loc, s1 - s0 * distShape[i],
- {delinearizedIds[i], mask.getOperand(i)});
+ {delinearizedIds[i], materializedOperands[i]});
newOperands.push_back(maskDimIdx);
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0cf6dd151e16c..135db02d543ef 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1779,6 +1779,21 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
// CHECK-DIST-AND-PROP: %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
// CHECK-DIST-AND-PROP: vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+// -----
+
+func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+ %1 = vector.constant_mask [1] : vector<32xi1>
+ gpu.yield %1 : vector<32xi1>
+ }
+ return %r : vector<1xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0] -> (-s0 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[MDIST]] : vector<1xi1>
// -----
@@ -1813,6 +1828,24 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+// -----
+
+func.func @warp_propagate_multi_dim_constant_mask(%laneid: index) -> vector<1x2x4xi1> {
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+ %1 = vector.constant_mask [1, 1, 2]: vector<16x4x4xi1>
+ gpu.yield %1 : vector<16x4x4xi1>
+ }
+ return %r : vector<1x2x4xi1>
+}
+
+// CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0] -> (-(s0 floordiv 2) + 1)>
+// CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0] -> (s0 * -2 + (s0 floordiv 2) * 4 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_constant_mask
+// CHECK-PROP-SAME: %[[LANEID:.+]]: index
+// CHECK-PROP: %[[CST2:.+]] = arith.constant 2 : index
+// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[LANEID]]]
+// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[LANEID]]]
+// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[CST2]] : vector<1x2x4xi1>
// -----
> /// %1 = vector.create_mask %ub : vector<1xi1>
> struct WarpOpCreateMask : public WarpDistributionPattern {
>   using Base::Base;
>   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
The function comments probably need to be tweaked a bit.
What is the need for merging the logic with `CreateMask`? Wouldn't a separate pattern be easier to understand and maintain?
The semantics of both ops are the same, and their interfaces are mostly the same (a list of sizes, given as SSA values or as an attribute). The two ops are tightly coupled, hence their distributions are as well. It only costs a few lines of trivial code to fit into the existing distribution logic. Does it really require a new (and mostly identical) pattern?
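For illustration only (not part of the PR): the two ops describe the same mask, differing only in whether the in-bounds sizes are carried as an attribute or as SSA operands.

```mlir
// vector.constant_mask carries the sizes as an attribute...
%m0 = vector.constant_mask [2] : vector<8xi1>

// ...while vector.create_mask takes them as SSA values.
%c2 = arith.constant 2 : index
%m1 = vector.create_mask %c2 : vector<8xi1>

// Both yield the mask [1, 1, 0, 0, 0, 0, 0, 0].
```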
In that case, I suggest using the format used in WGtoSG (i.e., template-based naming).
charithaintc
left a comment
Looks good.
I suggest using a template to distinguish the two patterns, e.g. `WarpOpMaskOp<OpType>`.
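A rough sketch of what such a templated pattern could look like (hypothetical naming and helper, following the `WarpOpMaskOp<OpType>` suggestion; this is not what the PR currently implements):

```cpp
// Hypothetical sketch: one pattern templated on the mask op kind.
template <typename MaskOpTy>
struct WarpOpMaskOp : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, llvm::IsaPred<MaskOpTy>);
    if (!yieldOperand)
      return failure();
    auto mask = yieldOperand->get().getDefiningOp<MaskOpTy>();
    // `materializeMaskSizes` is a hypothetical helper, specialized per op:
    // it forwards the SSA operands of vector.create_mask, or creates
    // arith.constant index ops from the attribute of vector.constant_mask.
    SmallVector<Value> maskSizes = materializeMaskSizes(rewriter, mask);
    // The remainder of the rewrite (delinearizing the lane id and composing
    // the per-dimension affine maps) would be shared, exactly as in the
    // existing WarpOpCreateMask pattern; omitted here.
    return failure();
  }
};

// Registration would then instantiate the pattern for both ops, e.g.:
//   patterns.add<WarpOpMaskOp<vector::CreateMaskOp>,
//                WarpOpMaskOp<vector::ConstantMaskOp>>(...);
```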
I haven't really worked on distribution, so will leave the reviewing to experts.
Let me just mention that we are considering merging […]
This PR enables `vector::ConstantMaskOp` (attribute-based indices) distribution by extending the distribution of its SSA-variant sibling `vector::CreateMaskOp`. Both ops offer equivalent semantics, so we can materialize attributes as SSA values and plug into the existing distribution logic.
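For reference, a minimal before/after picture of what the pattern does, paraphrased from the single-dimension test added in this PR (the exact IR produced by propagation may differ slightly, e.g. in how the affine maps are canonicalized):

```mlir
// Before: the constant mask is produced inside the warp op.
func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
    %1 = vector.constant_mask [1] : vector<32xi1>
    gpu.yield %1 : vector<32xi1>
  }
  return %r : vector<1xi1>
}

// After propagation: the constant mask size is materialized, offset by the
// lane id, and a per-lane vector.create_mask is emitted outside the warp op.
#sub = affine_map<()[s0] -> (-s0 + 1)>
func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
  %sz = affine.apply #sub()[%laneid]
  %m = vector.create_mask %sz : vector<1xi1>
  return %m : vector<1xi1>
}
```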