diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index 9ba3ea3ed77..ae9bf1993f1 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -334,7 +334,12 @@ public enum MemoryManager { public static boolean AUTO_GPU_CACHE_EVICTION = true; - ////////////////////// + /** + * Boolean specifying if relational algebra rewrites are allowed (e.g. Selection Pushdowns). + */ + public static boolean ALLOW_RA_REWRITES = false; + + ////////////////////// // Optimizer levels // ////////////////////// diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index c2602dba510..4fd3234150f 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -117,6 +117,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); if( LineageCacheConfig.getCompAssRW() ) _sbRuleSet.add( new MarkForLineageReuse() ); + if( OptimizerUtils.ALLOW_RA_REWRITES ) + _sbRuleSet.add( new RewriteRaPushdown() ); _sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() ); _dagRuleSet.add( new RewriteNonScalarPrint() ); } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java new file mode 100644 index 00000000000..eeff97ea153 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java @@ -0,0 +1,191 @@ +package org.apache.sysds.hops.rewrite; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.*; +import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.VariableSet; + +import java.util.ArrayList; +import java.util.List; + +/** + * Rule: Simplify program structure by rewriting relational expressions, + * implemented here: Pushdown of Selections before Join. + */ +public class RewriteRaPushdown extends StatementBlockRewriteRule +{ + @Override + public boolean createsSplitDag() { + return false; + } + + @Override + public List rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) { + ArrayList ret = new ArrayList<>(); + ret.add(sb); + return ret; + } + + @Override + public List rewriteStatementBlocks(List sbs, ProgramRewriteStatus state) { + if (sbs == null || sbs.size() <= 1) + return sbs; + + ArrayList tmpList = new ArrayList<>(sbs); + boolean changed = false; + + // iterate over all SBs including a FuncOp with FuncName m_raJoin + for (int i : findFunctionSb(tmpList, "m_raJoin", 0)){ + StatementBlock sb1 = tmpList.get(i); + FunctionOp joinOp = findFunctionOp(sb1.getHops(), "m_raJoin"); + + // iterate over all following SBs including a FuncOp with FuncName m_raSelection + for (int j : findFunctionSb(tmpList, "m_raSelection", i+1)){ + StatementBlock sb2 = tmpList.get(j); + FunctionOp selOp = findFunctionOp(sb2.getHops(), "m_raSelection"); + + // create deep copy to ensure data consistency + FunctionOp tmpJoinOp = (FunctionOp) Recompiler.deepCopyHopsDag(joinOp); + FunctionOp tmpSelOp = (FunctionOp) Recompiler.deepCopyHopsDag(selOp); + + if (!checkDataDependency(tmpJoinOp, tmpSelOp)){continue;} + + Hop selColHop = tmpSelOp.getInput(1); + long selCol = getConstantSelectionCol(selColHop); + if (selCol <= 0) + continue; + + // collect Variable Sets + VariableSet joinRead = new VariableSet(sb1.variablesRead()); + VariableSet joinUpdated = new VariableSet(sb1.variablesUpdated()); + VariableSet selRead = new VariableSet(sb2.variablesRead()); + + // join inputs: [A, colA, B, colB, method] + long colsLeft = tmpJoinOp.getInput(0).getDataCharacteristics().getCols(); + long colsRight = tmpJoinOp.getInput(2).getDataCharacteristics().getCols(); + if (colsLeft <= 0 || colsRight <= 0) + continue; + + // decide which side of inner join the selection belongs to (A / B) + int selSideIdx; + if (selCol <= colsLeft) { + selSideIdx = 0; + } + else if (selCol <= colsLeft + colsRight) { + selSideIdx = 2; + LiteralOp adjustedColHop = new LiteralOp(selCol - colsLeft); + adjustedColHop.setName(selColHop.getName()); + HopRewriteUtils.replaceChildReference(tmpSelOp, selColHop, adjustedColHop, 1); + } + else { continue; } // invalid column index + + // switch funcOps Output Variables + String joinOutVar = tmpJoinOp.getOutputVariableNames()[0]; + tmpJoinOp.getOutputVariableNames()[0] = tmpSelOp.getOutputVariableNames()[0]; + tmpSelOp.getOutputVariableNames()[0] = joinOutVar; + + // rewire selection to consume the correct join input and adjusted column + Hop newSelInput = tmpJoinOp.getInput().get(selSideIdx); + HopRewriteUtils.replaceChildReference(tmpSelOp, tmpSelOp.getInput().get(0), newSelInput, 0); + + // let the join take selection output instead of raw input + Hop newJoinInput = HopRewriteUtils.createTransientRead(joinOutVar, tmpSelOp); + HopRewriteUtils.replaceChildReference(tmpJoinOp, newSelInput, newJoinInput, selSideIdx); + + //switch StatementBlock-assignments + sb1.getHops().remove(joinOp); + sb1.getHops().add(tmpSelOp); + sb2.getHops().remove(selOp); + sb2.getHops().add(tmpJoinOp); + + // modify SB- variable sets + VariableSet vs = new VariableSet(); + vs.addVariable(joinOutVar, joinUpdated.getVariable(joinOutVar)); + selRead.removeVariables(vs); + selRead.addVariable(newSelInput.getName(), joinRead.getVariable(newSelInput.getName())); + + // selection now reads the original join inputs plus its own metadata + sb1.setReadVariables(selRead); + sb1.setLiveOut(VariableSet.minus(joinUpdated, selRead)); + sb1.setLiveIn(selRead); + sb1.setGen(selRead); + + // join now consumes the selection output and produces the output + sb2.setReadVariables(sb1.liveOut()); + sb2.setGen(sb1.liveOut()); + sb2.setLiveIn(sb1.liveOut()); + + // mark change & increment i by 1 (i+1 = now join-Sb) + changed = true; + i++; + + LOG.debug("Applied rewrite: pushed m_raSelection before m_raJoin (blocks lines " + + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + + sb2.getBeginLine() + "-" + sb2.getEndLine() + ")."); + } + } + return changed ? tmpList : sbs; + } + + private List findFunctionSb(List sbs, String functionName, int startIdx) { + List functionSbs = new ArrayList<>(); + + for (int i = startIdx; i < sbs.size(); i++) { + StatementBlock sb = sbs.get(i); + + // easy preconditions + if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.isSplitDag()) { + continue; + } + + // find if StatementBlocks have certain FunctionOp, continue if not found + FunctionOp functionOp = findFunctionOp(sb.getHops(), functionName); + + // if found, add to list + if (functionOp != null) { functionSbs.add(i); } + } + + return functionSbs; + } + + private boolean checkDataDependency(FunctionOp fOut, FunctionOp fIn){ + for (String out : fOut.getOutputVariableNames()) { + for (Hop h : fIn.getInput()) { + if (h.getName().equals(out)){ + return true; + } + } + } + return false; + } + + private FunctionOp findFunctionOp(List roots, String functionName) { + if (roots == null) + return null; + Hop.resetVisitStatus(roots, true); + for (Hop root : roots) { + if (root instanceof FunctionOp funcOp) { + if (funcOp.getFunctionName().equals(functionName)) + { return funcOp; } + } + } + return null; + } + + private long getConstantSelectionCol(Hop selColHop) { + if (selColHop instanceof LiteralOp lit) + return HopRewriteUtils.getIntValueSafe(lit); + + // Handle casted literals (e.g., type propagation inserted casts) + if (selColHop instanceof UnaryOp uop && uop.getOp() == Types.OpOp1.CAST_AS_INT + && uop.getInput().get(0) instanceof LiteralOp lit) + return HopRewriteUtils.getIntValueSafe(lit); + + // If hop is a dataop whose input is a literal, try to fold + if (selColHop instanceof DataOp dop && !dop.getInput().isEmpty() && dop.getInput().get(0) instanceof LiteralOp lit) + return HopRewriteUtils.getIntValueSafe(lit); + + return -1; // unknown at rewrite time + } +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java new file mode 100644 index 00000000000..602811217ab --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import java.util.HashMap; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class RewritePushdownRaSelectionTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "RewritePushdownRaSelection"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRaSelectionTest.class.getSimpleName() + "/"; + + private static final double eps = 1e-8; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"result"})); + } + + @Test + public void testRewritePushdownRaSelectionNoRewrite() { + int col = 1; + String op = Opcodes.EQUAL.toString(); + double val = 4.0; + + // Expected output matrix + double[][] Y = { + {4,7,8,4,7,8}, + {4,7,8,4,5,10}, + {4,3,5,4,7,8}, + {4,3,5,4,5,10}, + }; + + testRewritePushdownRaSelection(col, op, val, Y, "nested-loop", false); + } + + @Test + public void testRewritePushdownRaSelection1() { + int col = 1; + String op = Opcodes.EQUAL.toString(); + double val = 4.0; + + // Expected output matrix + double[][] Y = { + {4,7,8,4,7,8}, + {4,7,8,4,5,10}, + {4,3,5,4,7,8}, + {4,3,5,4,5,10}, + }; + + testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true); + } + + @Test + public void testRewritePushdownRaSelection2() { + int col = 5; + String op = Opcodes.EQUAL.toString(); + double val = 7.0; + + // Expected output matrix + double[][] Y = { + {4,7,8,4,7,8}, + {4,3,5,4,7,8}, + }; + + testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true); + } + + private void testRewritePushdownRaSelection(int col, String op, double val, double[][] Y, + String method, boolean rewrites) { + + //generate actual dataset and variables + double[][] A = { + {1, 2, 3}, + {4, 7, 8}, + {1, 3, 6}, + {4, 3, 5}, + {5, 8, 9} + }; + double[][] B = { + {1, 2, 9}, + {3, 7, 6}, + {2, 8, 5}, + {4, 7, 8}, + {4, 5, 10} + }; + int colA = 1; + int colB = 1; + + runRewritePushdownRaSelectionTest(A, colA, B, colB, Y, col, op, val, method, rewrites); + } + + + private void runRewritePushdownRaSelectionTest(double [][] A, int colA, double [][] B, int colB, double [][] Y, + int col, String op, double val, String method, boolean rewrites) + { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + boolean oldFlag = OptimizerUtils.ALLOW_RA_REWRITES; + + try + { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-explain", "hops", "-args", + input("A"), String.valueOf(colA), input("B"), + String.valueOf(colB), String.valueOf(col), op, String.valueOf(val), method, output("result") }; + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("B", B, true); + + OptimizerUtils.ALLOW_RA_REWRITES = rewrites; + + // run dmlScript + runTest(null); + + //compare matrices + HashMap dmlfile = readDMLMatrixFromOutputDir("result"); + HashMap expectedOutput = TestUtils.convert2DDoubleArrayToHashMap(Y); + TestUtils.compareMatrices(dmlfile, expectedOutput, eps, "Stat-DML", "Expected"); + } + finally { + rtplatform = platformOld; + OptimizerUtils.ALLOW_RA_REWRITES = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml new file mode 100644 index 00000000000..ef149acb35c --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A= read($1) +colA = as.integer($2) +B = read($3) +colB = as.integer($4) +op = $6 + +C = raJoin(A, colA, B, colB, $8); +result = raSelection(C, $5, op, $7); + +# the above will be rewritten into: +# +# C = raSelection(A, col, op, val); +# result = raJoin(C, colA, B, colB, method); +# or (depending on col): +# C = raSelection(B, (col - A.cols), op, val); +# result = raJoin(A, colA, C, colB, method); + +write(result, $9); +print(toString(result)) \ No newline at end of file