diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 288740561..10cb02c17 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -844,6 +844,25 @@ class QCProgramBuilder final : public OpBuilder { QCProgramBuilder& ctrl(ValueRange controls, const std::function& body); + /** + * @brief Apply an inverse (i.e., adjoint) operation. + * + * @param body Function that builds the body containing the operation to + * invert + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.inv([&](auto& b) { b.s(q0); }); + * ``` + * ```mlir + * qc.inv { + * qc.s %q0 : !qc.qubit + * } + * ``` + */ + QCProgramBuilder& inv(const std::function& body); + //===--------------------------------------------------------------------===// // Deallocation //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 7ed127dbf..e513ab4c5 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -973,4 +973,48 @@ def CtrlOp : QCOp<"ctrl", let hasVerifier = 1; } +def InvOp : QCOp<"inv", + traits = [ + UnitaryOpInterface, + SingleBlockImplicitTerminator<"::mlir::qc::YieldOp">, + RecursiveMemoryEffects + ]> { + let summary = "Invert a unitary operation"; + let description = [{ + A modifier operation that inverts the unitary operation defined in its body + region. + + Example: + ```mlir + qc.inv { + qc.s %q0 : !qc.qubit + } + ``` + }]; + + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "$body attr-dict"; + + let extraClassDeclaration = [{ + [[nodiscard]] UnitaryOpInterface getBodyUnitary(); + size_t getNumQubits(); + size_t getNumTargets(); + size_t getNumControls(); + Value getQubit(size_t i); + Value getTarget(size_t i); + Value getControl(size_t i); + size_t getNumParams(); + Value getParameter(size_t i); + static StringRef getBaseSymbol() { return "inv"; } + }]; + + let builders = [ + OpBuilder<(ins "UnitaryOpInterface":$bodyUnitary)>, + OpBuilder<(ins "const std::function&":$bodyBuilder)> + ]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + #endif // QC_OPS diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index cd5f4ae89..2f965728c 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1003,6 +1003,30 @@ class QCOProgramBuilder final : public OpBuilder { ctrl(ValueRange controls, ValueRange targets, const std::function& body); + /** + * @brief Apply an inverse operation + * + * @param targets Target qubits + * @param body Function that builds the body containing the target operation + * @return Output target qubits + * + * @par Example: + * ```c++ + * targets_out = builder.inv(q0_in, [&](auto& b) { + * auto q0_res = b.s(q0_in); + * return {q0_res}; + * }); + * ``` + * ```mlir + * %targets_out = qco.inv %q0_in { + * %q0_res = qco.s %q0_in : !qco.qubit -> !qco.qubit + * qco.yield %q0_res + * } : {!qco.qubit} -> {!qco.qubit} + * ``` + */ + ValueRange inv(ValueRange targets, + const std::function& body); + //===--------------------------------------------------------------------===// // Deallocation //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 69f89eb5e..8c96ca5d1 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1106,4 +1106,68 @@ def CtrlOp : QCOOp<"ctrl", traits = let hasVerifier = 1; } +def InvOp : QCOOp<"inv", traits = + [ + UnitaryOpInterface, + SameOperandsAndResultType, + SameOperandsAndResultShape, + SingleBlock, + RecursiveMemoryEffects + ]> { + let summary = "Invert a unitary operation"; + let description = [{ + A modifier operation that inverts the unitary operation defined in its body + region. The operation takes a variadic number of target qubits as inputs and + produces corresponding output qubits. + + Example: + ```mlir + %targets_out = qco.inv %targets_in { + %targets_res = qco.s %targets_in : !qco.qubit -> !qco.qubit + qco.yield %targets_res : !qco.qubit + } : {!qco.qubit} -> {!qco.qubit} + ``` + }]; + + let arguments = (ins Arg, "the target qubits", [MemRead]>:$targets_in); + let results = (outs Variadic:$targets_out); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = [{ + $targets_in + $body attr-dict `:` + `{` type($targets_in) `}` + `->` + `{` type($targets_out) `}` + }]; + + let extraClassDeclaration = [{ + UnitaryOpInterface getBodyUnitary(); + size_t getNumQubits(); + size_t getNumTargets(); + size_t getNumControls(); + Value getInputQubit(size_t i); + Value getOutputQubit(size_t i); + Value getInputTarget(size_t i); + Value getOutputTarget(size_t i); + Value getInputControl(size_t i); + Value getOutputControl(size_t i); + Value getInputForOutput(Value output); + Value getOutputForInput(Value input); + size_t getNumParams(); + Value getParameter(size_t i); + static StringRef getBaseSymbol() { return "inv"; } + }]; + + let builders = [ + OpBuilder<(ins "ValueRange":$targets), [{ + build($_builder, $_state, targets.getTypes(), targets); + }]>, + OpBuilder<(ins "ValueRange":$targets, "UnitaryOpInterface":$bodyUnitary)>, + OpBuilder<(ins "ValueRange":$targets, "const std::function&":$bodyBuilder)> + ]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + #endif // QCOOPS diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 9e89791b0..58018259c 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -787,6 +787,43 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { } }; +/** + * @brief Converts qco.inv to qc.inv + * + * @par Example: + * ```mlir + * %targets_out = qco.inv %q0_in { + * %q0_res = qco.s %q0_in : !qco.qubit -> !qco.qubit + * qco.yield %q0_res + * } : {!qco.qubit} -> {!qco.qubit} + * ``` + * is converted to + * ```mlir + * qc.inv { + * qc.s %q0 : !qc.qubit + * } + * ``` + */ +struct ConvertQCOInvOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(qco::InvOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // Create qc.inv operation + auto qcOp = qc::InvOp::create(rewriter, op->getLoc()); + + // Clone body region from QCO to QC + auto& dstRegion = qcOp.getBody(); + rewriter.cloneRegionBefore(op.getBody(), dstRegion, dstRegion.end()); + + // Replace the output qubits with the same QC references + rewriter.replaceOp(op, adaptor.getOperands()); + + return success(); + } +}; + /** * @brief Converts qco.yield to qc.yield * @@ -854,18 +891,18 @@ struct QCOToQC final : impl::QCOToQCBase { // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion - patterns.add( - typeConverter, context); + patterns + .add(typeConverter, context); // Conversion of qco types in func.func signatures // Note: This currently has limitations with signature changes diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 224afb3c9..b76268c9a 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1135,6 +1135,68 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { } }; +/** + * @brief Converts qc.inv to qco.inv + * + * @par Example: + * ```mlir + * qc.inv { + * qc.s %q0 + * } + * ``` + * is converted to + * ```mlir + * %targets_out = qco.inv %q0_in { + * %q0_res = qco.s %q0_in : !qco.qubit -> !qco.qubit + * qco.yield %q0_res + * } : {!qco.qubit} -> {!qco.qubit} + * ``` + */ +struct ConvertQCInvOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(qc::InvOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + auto& qubitMap = state.qubitMap; + + // Get QCO targets from state map + const auto numTargets = op.getNumTargets(); + SmallVector qcoTargets; + qcoTargets.reserve(numTargets); + for (size_t i = 0; i < numTargets; ++i) { + const auto& qcTarget = op.getTarget(i); + assert(qubitMap.contains(qcTarget) && "QC qubit not found"); + const auto& qcoTarget = qubitMap[qcTarget]; + qcoTargets.push_back(qcoTarget); + } + + // Create qco.inv + auto qcoOp = qco::InvOp::create(rewriter, op.getLoc(), qcoTargets); + + // Update state map + if (state.inCtrlOp == 0) { + const auto targetsOut = qcoOp.getTargetsOut(); + for (size_t i = 0; i < numTargets; ++i) { + const auto& qcTarget = op.getTarget(i); + qubitMap[qcTarget] = targetsOut[i]; + } + } + + // Update modifier information + state.inCtrlOp++; + state.targetsIn.try_emplace(state.inCtrlOp, qcoTargets); + + // Clone body region from QC to QCO + auto& dstRegion = qcoOp.getBody(); + rewriter.cloneRegionBefore(op.getBody(), dstRegion, dstRegion.end()); + + rewriter.eraseOp(op); + return success(); + } +}; + /** * @brief Converts qc.yield to qco.yield * @@ -1216,7 +1278,8 @@ struct QCToQCO final : impl::QCToQCOBase { ConvertQCDCXOp, ConvertQCECROp, ConvertQCRXXOp, ConvertQCRYYOp, ConvertQCRZXOp, ConvertQCRZZOp, ConvertQCXXPlusYYOp, ConvertQCXXMinusYYOp, ConvertQCBarrierOp, ConvertQCCtrlOp, - ConvertQCYieldOp>(typeConverter, context, &state); + ConvertQCInvOp, ConvertQCYieldOp>(typeConverter, context, + &state); // Conversion of qc types in func.func signatures // Note: This currently has limitations with signature diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index f9139835f..ccaed2386 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -425,6 +425,13 @@ QCProgramBuilder::ctrl(ValueRange controls, return *this; } +QCProgramBuilder& +QCProgramBuilder::inv(const std::function& body) { + checkFinalized(); + InvOp::create(*this, loc, body); + return *this; +} + //===----------------------------------------------------------------------===// // Deallocation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp new file mode 100644 index 000000000..abb7ce980 --- /dev/null +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QC/IR/QCDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qc; + +namespace { + +/** + * @brief Cancel nested inverse modifiers, i.e., `inv(inv(x)) => x`. + */ +struct CancelNestedInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp invOp, + PatternRewriter& rewriter) const override { + auto innerUnitary = invOp.getBodyUnitary(); + auto innerInvOp = llvm::dyn_cast(innerUnitary.getOperation()); + if (!innerInvOp) { + return failure(); + } + + auto innerInnerUnitary = innerInvOp.getBodyUnitary(); + auto* clonedOp = rewriter.clone(*innerInnerUnitary.getOperation()); + rewriter.replaceOp(invOp, clonedOp->getResults()); + + return success(); + } +}; + +} // namespace + +UnitaryOpInterface InvOp::getBodyUnitary() { + return llvm::dyn_cast(&getBody().front().front()); +} + +size_t InvOp::getNumQubits() { return getNumTargets() + getNumControls(); } + +size_t InvOp::getNumTargets() { return getBodyUnitary().getNumTargets(); } + +size_t InvOp::getNumControls() { return getBodyUnitary().getNumControls(); } + +Value InvOp::getQubit(const size_t i) { return getBodyUnitary().getQubit(i); } + +Value InvOp::getTarget(const size_t i) { return getBodyUnitary().getTarget(i); } + +Value InvOp::getControl(const size_t i) { + return getBodyUnitary().getControl(i); +} + +size_t InvOp::getNumParams() { return getBodyUnitary().getNumParams(); } + +Value InvOp::getParameter(const size_t i) { + return getBodyUnitary().getParameter(i); +} + +void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, + UnitaryOpInterface bodyUnitary) { + const OpBuilder::InsertionGuard guard(odsBuilder); + auto* region = odsState.addRegion(); + auto& block = region->emplaceBlock(); + + // Move the unitary op into the block + odsBuilder.setInsertionPointToStart(&block); + odsBuilder.clone(*bodyUnitary.getOperation()); + YieldOp::create(odsBuilder, odsState.location); +} + +void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, + const std::function& bodyBuilder) { + const OpBuilder::InsertionGuard guard(odsBuilder); + auto* region = odsState.addRegion(); + auto& block = region->emplaceBlock(); + + odsBuilder.setInsertionPointToStart(&block); + bodyBuilder(odsBuilder); + YieldOp::create(odsBuilder, odsState.location); +} + +LogicalResult InvOp::verify() { + auto& block = getBody().front(); + if (block.getOperations().size() != 2) { + return emitOpError("body region must have exactly two operations"); + } + if (!llvm::isa(block.front())) { + return emitOpError( + "first operation in body region must be a unitary operation"); + } + if (!llvm::isa(block.back())) { + return emitOpError( + "second operation in body region must be a yield operation"); + } + + llvm::SmallPtrSet uniqueQubits; + auto bodyUnitary = getBodyUnitary(); + const auto numQubits = bodyUnitary.getNumQubits(); + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubits.insert(bodyUnitary.getQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + + if (llvm::isa(bodyUnitary.getOperation())) { + return emitOpError("BarrierOp cannot be inverted"); + } + + return success(); +} + +void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index a3fa2123b..98033ed5c 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -596,6 +596,22 @@ std::pair QCOProgramBuilder::ctrl( return {controlsOut, targetsOut}; } +ValueRange QCOProgramBuilder::inv( + ValueRange targets, + const std::function& body) { + checkFinalized(); + + auto invOp = InvOp::create(*this, loc, targets, body); + + // Update tracking + const auto& targetsOut = invOp.getTargetsOut(); + for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) { + updateQubitTracking(target, targetOut); + } + + return targetsOut; +} + //===----------------------------------------------------------------------===// // Deallocation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp new file mode 100644 index 000000000..3176bae01 --- /dev/null +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCODialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qco; + +namespace { + +/** + * @brief Cancel nested inverse modifiers, i.e., `inv(inv(x)) => x`. + */ +struct CancelNestedInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + auto innerUnitary = op.getBodyUnitary(); + auto innerInvOp = llvm::dyn_cast(innerUnitary.getOperation()); + if (!innerInvOp) { + return failure(); + } + + // Remove both inverse operations + auto innerInnerUnitary = innerInvOp.getBodyUnitary(); + auto* clonedOp = rewriter.clone(*innerInnerUnitary.getOperation()); + rewriter.replaceOp(op, clonedOp->getResults()); + + return success(); + } +}; + +} // namespace + +UnitaryOpInterface InvOp::getBodyUnitary() { + return llvm::dyn_cast(&getBody().front().front()); +} + +size_t InvOp::getNumQubits() { return getNumTargets() + getNumControls(); } + +size_t InvOp::getNumTargets() { return getTargetsIn().size(); } + +size_t InvOp::getNumControls() { return getBodyUnitary().getNumControls(); } + +Value InvOp::getInputQubit(const size_t i) { + return getBodyUnitary().getInputQubit(i); +} + +Value InvOp::getOutputQubit(const size_t i) { + return getBodyUnitary().getOutputQubit(i); +} + +Value InvOp::getInputTarget(const size_t i) { + if (i >= getNumTargets()) { + llvm::reportFatalUsageError("Target index out of bounds"); + } + return getTargetsIn()[i]; +} + +Value InvOp::getOutputTarget(const size_t i) { + if (i >= getNumTargets()) { + llvm::reportFatalUsageError("Target index out of bounds"); + } + return getTargetsOut()[i]; +} + +Value InvOp::getInputControl(const size_t i) { + return getBodyUnitary().getInputControl(i); +} + +Value InvOp::getOutputControl(const size_t i) { + return getBodyUnitary().getOutputControl(i); +} + +Value InvOp::getInputForOutput(Value output) { + for (size_t i = 0; i < getNumTargets(); ++i) { + if (output == getTargetsOut()[i]) { + return getTargetsIn()[i]; + } + } + llvm::reportFatalUsageError("Given qubit is not an output of the operation"); +} + +Value InvOp::getOutputForInput(Value input) { + for (size_t i = 0; i < getNumTargets(); ++i) { + if (input == getTargetsIn()[i]) { + return getTargetsOut()[i]; + } + } + llvm::reportFatalUsageError("Given qubit is not an input of the operation"); +} + +size_t InvOp::getNumParams() { return getBodyUnitary().getNumParams(); } + +Value InvOp::getParameter(const size_t i) { + return getBodyUnitary().getParameter(i); +} + +void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, + const ValueRange targets, UnitaryOpInterface bodyUnitary) { + build(odsBuilder, odsState, targets); + auto& block = odsState.regions.front()->emplaceBlock(); + + // Move the unitary op into the block + const OpBuilder::InsertionGuard guard(odsBuilder); + odsBuilder.setInsertionPointToStart(&block); + auto* op = odsBuilder.clone(*bodyUnitary.getOperation()); + YieldOp::create(odsBuilder, odsState.location, op->getResults()); +} + +void InvOp::build( + OpBuilder& odsBuilder, OperationState& odsState, const ValueRange targets, + const std::function& bodyBuilder) { + build(odsBuilder, odsState, targets); + auto& block = odsState.regions.front()->emplaceBlock(); + + // Move the unitary op into the block + const OpBuilder::InsertionGuard guard(odsBuilder); + odsBuilder.setInsertionPointToStart(&block); + auto targetsOut = bodyBuilder(odsBuilder, targets); + YieldOp::create(odsBuilder, odsState.location, targetsOut); +} + +LogicalResult InvOp::verify() { + auto& block = getBody().front(); + if (block.getOperations().size() != 2) { + return emitOpError("body region must have exactly two operations"); + } + if (!llvm::isa(block.front())) { + return emitOpError( + "first operation in body region must be a unitary operation"); + } + if (!llvm::isa(block.back())) { + return emitOpError( + "second operation in body region must be a yield operation"); + } + if (block.back().getNumOperands() != getNumTargets()) { + return emitOpError("yield operation must yield ") + << getNumTargets() << " values, but found " + << block.back().getNumOperands(); + } + + SmallPtrSet uniqueQubitsIn; + auto bodyUnitary = getBodyUnitary(); + const auto numQubits = bodyUnitary.getNumQubits(); + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubitsIn.insert(bodyUnitary.getInputQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + SmallPtrSet uniqueQubitsOut; + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubitsOut.insert(bodyUnitary.getOutputQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + + if (llvm::isa(bodyUnitary.getOperation())) { + return emitOpError("BarrierOp cannot be inverted"); + } + + return success(); +} + +void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +}