Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,12 @@ public enum MemoryManager {

public static boolean AUTO_GPU_CACHE_EVICTION = true;

//////////////////////
/**
* Boolean specifying if relational algebra rewrites are allowed (e.g. Selection Pushdowns).
* Disabled by default. When set to true, ProgramRewriter registers the RewriteRaPushdown
* statement-block rewrite; the flag is read when the ProgramRewriter rule set is built.
*/
public static boolean ALLOW_RA_REWRITES = false;

//////////////////////
// Optimizer levels //
//////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() );
if( LineageCacheConfig.getCompAssRW() )
_sbRuleSet.add( new MarkForLineageReuse() );
if( OptimizerUtils.ALLOW_RA_REWRITES )
_sbRuleSet.add( new RewriteRaPushdown() );
_sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() );
_dagRuleSet.add( new RewriteNonScalarPrint() );
}
Expand Down
191 changes: 191 additions & 0 deletions src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package org.apache.sysds.hops.rewrite;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.*;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;

import java.util.ArrayList;
import java.util.List;

/**
* Rule: Simplify program structure by rewriting relational expressions,
* implemented here: Pushdown of Selections before Join.
*/
/**
 * Rule: Simplify program structure by rewriting relational expressions,
 * implemented here: pushdown of selections before joins.
 *
 * <p>The rule looks for a statement block whose HOP roots contain a call to
 * {@code m_raJoin} and a later statement block whose roots contain a call to
 * {@code m_raSelection} that consumes the join output. If the selection column
 * is a compile-time constant and the column counts of both join inputs are
 * known, the two function calls swap places: the selection is applied to the
 * join input that owns the selected column, and the join then consumes the
 * (smaller) selection result.
 */
public class RewriteRaPushdown extends StatementBlockRewriteRule
{
	/** Function name of the relational join builtin this rule matches. */
	private static final String RA_JOIN = "m_raJoin";

	/** Function name of the relational selection builtin this rule matches. */
	private static final String RA_SELECTION = "m_raSelection";

	/**
	 * Indicates whether this rule splits DAGs.
	 *
	 * @return always {@code false}
	 */
	@Override
	public boolean createsSplitDag() {
		return false;
	}

	/**
	 * No-op for single statement blocks; the pushdown needs at least two
	 * blocks and is therefore implemented in
	 * {@link #rewriteStatementBlocks(List, ProgramRewriteStatus)}.
	 *
	 * @param sb    statement block
	 * @param state rewrite status (unused)
	 * @return a new list containing only {@code sb}
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
		List<StatementBlock> ret = new ArrayList<>();
		ret.add(sb);
		return ret;
	}

	/**
	 * Pushes selections in front of the joins they depend on.
	 *
	 * @param sbs   statement blocks of one program level, may be {@code null}
	 * @param state rewrite status (unused)
	 * @return {@code sbs} itself if nothing was rewritten (or if it holds fewer
	 *         than two blocks), otherwise a copy of the list; the statement
	 *         blocks themselves are modified in place
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
		if (sbs == null || sbs.size() <= 1)
			return sbs;

		ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
		boolean changed = false;

		// iterate over all SBs including a FuncOp with FuncName m_raJoin
		for (int i : findFunctionSb(tmpList, RA_JOIN, 0)) {
			StatementBlock sb1 = tmpList.get(i);
			FunctionOp joinOp = findFunctionOp(sb1.getHops(), RA_JOIN);
			// join inputs: [A, colA, B, colB, method]; indexes 0 and 2 are accessed below
			if (joinOp == null || joinOp.getInput().size() < 3)
				continue;

			// iterate over all following SBs including a FuncOp with FuncName m_raSelection
			for (int j : findFunctionSb(tmpList, RA_SELECTION, i + 1)) {
				StatementBlock sb2 = tmpList.get(j);
				FunctionOp selOp = findFunctionOp(sb2.getHops(), RA_SELECTION);
				// selection inputs 0 (data) and 1 (column) are accessed below
				if (selOp == null || selOp.getInput().size() < 2)
					continue;

				// all applicability checks are read-only, so run them on the
				// original hops and pay for the deep copies only on a match
				if (!checkDataDependency(joinOp, selOp))
					continue;

				long selCol = getConstantSelectionCol(selOp.getInput(1));
				if (selCol <= 0)
					continue;

				long colsLeft = joinOp.getInput(0).getDataCharacteristics().getCols();
				long colsRight = joinOp.getInput(2).getDataCharacteristics().getCols();
				if (colsLeft <= 0 || colsRight <= 0)
					continue;
				if (selCol > colsLeft + colsRight)
					continue; // invalid column index

				applyPushdown(sb1, sb2, joinOp, selOp, selCol, colsLeft);
				changed = true;

				LOG.debug("Applied rewrite: pushed m_raSelection before m_raJoin (blocks lines "
					+ sb1.getBeginLine() + "-" + sb1.getEndLine() + " and "
					+ sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");

				// joinOp has been removed from sb1 (which now holds the selection),
				// so it must not be matched against further selections
				break;
			}
		}
		return changed ? tmpList : sbs;
	}

	/**
	 * Swaps the join of {@code sb1} with the selection of {@code sb2}: afterwards
	 * {@code sb1} holds the selection (applied to one join input) and {@code sb2}
	 * holds the join (consuming the selection result).
	 *
	 * @param sb1      statement block that currently holds {@code joinOp}
	 * @param sb2      statement block that currently holds {@code selOp}
	 * @param joinOp   join function call, root of {@code sb1}
	 * @param selOp    selection function call, root of {@code sb2}
	 * @param selCol   1-based selection column w.r.t. the join output
	 * @param colsLeft number of columns of the left join input
	 */
	private void applyPushdown(StatementBlock sb1, StatementBlock sb2,
		FunctionOp joinOp, FunctionOp selOp, long selCol, long colsLeft)
	{
		// create deep copy to ensure data consistency
		FunctionOp tmpJoinOp = (FunctionOp) Recompiler.deepCopyHopsDag(joinOp);
		FunctionOp tmpSelOp = (FunctionOp) Recompiler.deepCopyHopsDag(selOp);

		// collect Variable Sets
		VariableSet joinRead = new VariableSet(sb1.variablesRead());
		VariableSet joinUpdated = new VariableSet(sb1.variablesUpdated());
		VariableSet selRead = new VariableSet(sb2.variablesRead());

		// decide which side of inner join the selection belongs to (A / B)
		int selSideIdx = 0;
		if (selCol > colsLeft) {
			// column belongs to B: shift the index into B's own column space
			selSideIdx = 2;
			Hop selColHop = tmpSelOp.getInput(1);
			LiteralOp adjustedColHop = new LiteralOp(selCol - colsLeft);
			adjustedColHop.setName(selColHop.getName());
			HopRewriteUtils.replaceChildReference(tmpSelOp, selColHop, adjustedColHop, 1);
		}

		// switch funcOps Output Variables
		String joinOutVar = tmpJoinOp.getOutputVariableNames()[0];
		tmpJoinOp.getOutputVariableNames()[0] = tmpSelOp.getOutputVariableNames()[0];
		tmpSelOp.getOutputVariableNames()[0] = joinOutVar;

		// rewire selection to consume the correct join input and adjusted column
		Hop newSelInput = tmpJoinOp.getInput().get(selSideIdx);
		HopRewriteUtils.replaceChildReference(tmpSelOp, tmpSelOp.getInput().get(0), newSelInput, 0);

		// let the join take selection output instead of raw input
		Hop newJoinInput = HopRewriteUtils.createTransientRead(joinOutVar, tmpSelOp);
		HopRewriteUtils.replaceChildReference(tmpJoinOp, newSelInput, newJoinInput, selSideIdx);

		// switch StatementBlock-assignments
		sb1.getHops().remove(joinOp);
		sb1.getHops().add(tmpSelOp);
		sb2.getHops().remove(selOp);
		sb2.getHops().add(tmpJoinOp);

		// modify SB- variable sets
		VariableSet vs = new VariableSet();
		vs.addVariable(joinOutVar, joinUpdated.getVariable(joinOutVar));
		selRead.removeVariables(vs);
		selRead.addVariable(newSelInput.getName(), joinRead.getVariable(newSelInput.getName()));

		// selection now reads the original join inputs plus its own metadata
		sb1.setReadVariables(selRead);
		sb1.setLiveOut(VariableSet.minus(joinUpdated, selRead));
		sb1.setLiveIn(selRead);
		sb1.setGen(selRead);

		// join now consumes the selection output and produces the output
		sb2.setReadVariables(sb1.liveOut());
		sb2.setGen(sb1.liveOut());
		sb2.setLiveIn(sb1.liveOut());
	}

	/**
	 * Collects the indexes of all last-level, non-split statement blocks whose
	 * HOP roots contain a call to the given function.
	 *
	 * @param sbs          statement blocks to scan
	 * @param functionName function name to look for
	 * @param startIdx     first index (inclusive) to scan
	 * @return ascending list of matching indexes, possibly empty
	 */
	private List<Integer> findFunctionSb(List<StatementBlock> sbs, String functionName, int startIdx) {
		List<Integer> functionSbs = new ArrayList<>();

		for (int i = startIdx; i < sbs.size(); i++) {
			StatementBlock sb = sbs.get(i);

			// easy preconditions
			if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.isSplitDag())
				continue;

			// remember the block if one of its roots is the requested function call
			if (findFunctionOp(sb.getHops(), functionName) != null)
				functionSbs.add(i);
		}

		return functionSbs;
	}

	/**
	 * Checks whether any input hop of {@code fIn} carries the name of an output
	 * variable of {@code fOut}. Unnamed outputs and unnamed input hops never match.
	 *
	 * @param fOut producing function call
	 * @param fIn  potentially consuming function call
	 * @return {@code true} if at least one input name equals an output variable name
	 */
	private boolean checkDataDependency(FunctionOp fOut, FunctionOp fIn) {
		for (String out : fOut.getOutputVariableNames()) {
			if (out == null)
				continue;
			for (Hop h : fIn.getInput()) {
				// call equals on the non-null side to tolerate unnamed hops
				if (out.equals(h.getName()))
					return true;
			}
		}
		return false;
	}

	/**
	 * Returns the first root hop that is a call to the given function. Only the
	 * roots themselves are inspected, not their inputs.
	 *
	 * @param roots        HOP DAG roots, may be {@code null}
	 * @param functionName function name to look for
	 * @return the matching function call, or {@code null} if there is none
	 */
	private FunctionOp findFunctionOp(List<Hop> roots, String functionName) {
		if (roots == null)
			return null;
		Hop.resetVisitStatus(roots, true);
		for (Hop root : roots) {
			if (root instanceof FunctionOp funcOp && functionName.equals(funcOp.getFunctionName()))
				return funcOp;
		}
		return null;
	}

	/**
	 * Tries to resolve the selection column to a compile-time constant.
	 *
	 * @param selColHop hop providing the selection column
	 * @return the integer value obtained from the literal, or -1 if the column
	 *         is unknown at rewrite time
	 */
	private long getConstantSelectionCol(Hop selColHop) {
		if (selColHop instanceof LiteralOp lit)
			return HopRewriteUtils.getIntValueSafe(lit);

		// Handle casted literals (e.g., type propagation inserted casts)
		if (selColHop instanceof UnaryOp uop && uop.getOp() == Types.OpOp1.CAST_AS_INT
			&& uop.getInput().get(0) instanceof LiteralOp lit)
			return HopRewriteUtils.getIntValueSafe(lit);

		// If hop is a dataop whose input is a literal, try to fold
		if (selColHop instanceof DataOp dop && !dop.getInput().isEmpty()
			&& dop.getInput().get(0) instanceof LiteralOp lit)
			return HopRewriteUtils.getIntValueSafe(lit);

		return -1; // unknown at rewrite time
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.rewrite;

import java.util.HashMap;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

/**
 * End-to-end test of the relational-algebra selection pushdown: runs the
 * {@code RewritePushdownRaSelection} DML script with the rewrite flag on and
 * off and compares the script output against a hand-computed result.
 */
public class RewritePushdownRaSelectionTest extends AutomatedTestBase
{
	private static final String TEST_NAME = "RewritePushdownRaSelection";
	private static final String TEST_DIR = "functions/rewrite/";
	private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRaSelectionTest.class.getSimpleName() + "/";

	/** Tolerance used when comparing the script output with the expected matrix. */
	private static final double eps = 1e-8;

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"result"});
		addTestConfiguration(TEST_NAME, config);
	}

	/** Selection on column 1 == 4 with the rewrite disabled (nested-loop join). */
	@Test
	public void testRewritePushdownRaSelectionNoRewrite() {
		runPushdownScenario(1, Opcodes.EQUAL.toString(), 4.0, expectedForFirstColumnEqualsFour(),
			"nested-loop", false);
	}

	/** Selection on column 1 == 4 (a column of the left join input) with the rewrite enabled. */
	@Test
	public void testRewritePushdownRaSelection1() {
		runPushdownScenario(1, Opcodes.EQUAL.toString(), 4.0, expectedForFirstColumnEqualsFour(),
			"sort-merge", true);
	}

	/** Selection on column 5 == 7 (a column of the right join input) with the rewrite enabled. */
	@Test
	public void testRewritePushdownRaSelection2() {
		double[][] expected = {
			{4,7,8,4,7,8},
			{4,3,5,4,7,8},
		};
		runPushdownScenario(5, Opcodes.EQUAL.toString(), 7.0, expected, "sort-merge", true);
	}

	/**
	 * Expected output shared by the two tests that select rows whose first
	 * column equals 4; a fresh array is returned on every call.
	 */
	private static double[][] expectedForFirstColumnEqualsFour() {
		return new double[][] {
			{4,7,8,4,7,8},
			{4,7,8,4,5,10},
			{4,3,5,4,7,8},
			{4,3,5,4,5,10},
		};
	}

	/**
	 * Writes the fixed input matrices, runs the DML script in single-node mode
	 * with the given selection and join settings, and compares the result with
	 * {@code expected}. Execution mode and rewrite flag are restored afterwards.
	 *
	 * @param selCol        selection column passed to the script
	 * @param selOp         comparison operator passed to the script
	 * @param selVal        comparison value passed to the script
	 * @param expected      expected output matrix
	 * @param joinMethod    join method passed to the script
	 * @param enableRewrite value assigned to OptimizerUtils.ALLOW_RA_REWRITES for the run
	 */
	private void runPushdownScenario(int selCol, String selOp, double selVal, double[][] expected,
		String joinMethod, boolean enableRewrite)
	{
		// fixed input dataset; both sides are joined on their first column
		final double[][] left = {
			{1, 2, 3},
			{4, 7, 8},
			{1, 3, 6},
			{4, 3, 5},
			{5, 8, 9}
		};
		final double[][] right = {
			{1, 2, 9},
			{3, 7, 6},
			{2, 8, 5},
			{4, 7, 8},
			{4, 5, 10}
		};
		final int joinColLeft = 1;
		final int joinColRight = 1;

		Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);
		boolean previousFlag = OptimizerUtils.ALLOW_RA_REWRITES;

		try {
			loadTestConfiguration(getTestConfiguration(TEST_NAME));

			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
			programArgs = new String[] {
				"-explain", "hops", "-args",
				input("A"), Integer.toString(joinColLeft),
				input("B"), Integer.toString(joinColRight),
				Integer.toString(selCol), selOp, Double.toString(selVal),
				joinMethod, output("result")
			};
			writeInputMatrixWithMTD("A", left, true);
			writeInputMatrixWithMTD("B", right, true);

			OptimizerUtils.ALLOW_RA_REWRITES = enableRewrite;

			// run dmlScript
			runTest(null);

			// compare matrices
			HashMap<CellIndex, Double> actual = readDMLMatrixFromOutputDir("result");
			HashMap<CellIndex, Double> wanted = TestUtils.convert2DDoubleArrayToHashMap(expected);
			TestUtils.compareMatrices(actual, wanted, eps, "Stat-DML", "Expected");
		}
		finally {
			rtplatform = previousMode;
			OptimizerUtils.ALLOW_RA_REWRITES = previousFlag;
		}
	}
}
Loading