Skip to content

Commit

Permalink
[MINOR] Code cleanups in rewrites and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mboehm7 committed Dec 13, 2024
1 parent 2dcd822 commit e705f89
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos)
{
if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && !hi.isScalar() ) {
//remove unnecessary right indexing
Hop input = hi.getInput().get(0);
Hop input = hi.getInput(0);
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
Expand All @@ -258,8 +258,8 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos)
{
if( hi instanceof LeftIndexingOp && hi.getDataType() == DataType.MATRIX ) //left indexing op
{
Hop input1 = hi.getInput().get(0); //lhs matrix
Hop input2 = hi.getInput().get(1); //rhs matrix
Hop input1 = hi.getInput(0); //lhs matrix
Hop input2 = hi.getInput(1); //rhs matrix

if( input1.getNnz()==0 //nnz original known and empty
&& input2.getNnz()==0 ) //nnz input known and empty
Expand All @@ -271,7 +271,7 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos)
hi = hnew;

LOG.debug("Applied removeEmptyLeftIndexing");
}
}
}

return hi;
Expand All @@ -281,19 +281,19 @@ private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos)
{
if( hi instanceof LeftIndexingOp ) //left indexing op
{
Hop input = hi.getInput().get(1); //rhs matrix/frame
Hop input = hi.getInput(1); //rhs matrix/frame

if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims
{
//equal dims of left indexing input and output -> no need for indexing

//remove unnecessary right indexing
//remove unnecessary right indexing
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;

LOG.debug("Applied removeUnnecessaryLeftIndexing");
}
}
}

return hi;
Expand All @@ -306,15 +306,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos)
//pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
if( hi instanceof LeftIndexingOp //first lix
&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
&& hi.getInput().get(0) instanceof LeftIndexingOp //second lix
&& hi.getInput(0) instanceof LeftIndexingOp //second lix
&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
&& hi.getInput().get(0).getParent().size()==1 //first lix is single consumer
&& hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two column matrix
&& hi.getInput(0).getParent().size()==1 //first lix is single consumer
&& hi.getInput(0).getInput(0).getDim2() == 2 ) //two column matrix
{
Hop input2 = hi.getInput().get(1); //rhs matrix
Hop pred2 = hi.getInput().get(4); //cl=cu
Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix
Hop pred1 = hi.getInput().get(0).getInput().get(4); //cl=cu
Hop input2 = hi.getInput(1); //rhs matrix
Hop pred2 = hi.getInput(4); //cl=cu
Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
Hop pred1 = hi.getInput(0).getInput(4); //cl=cu

if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
Expand All @@ -332,15 +332,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos)
//pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B)
if( !applied && hi instanceof LeftIndexingOp //first lix
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
&& hi.getInput().get(0) instanceof LeftIndexingOp //second lix
&& hi.getInput(0) instanceof LeftIndexingOp //second lix
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
&& hi.getInput().get(0).getParent().size()==1 //first lix is single consumer
&& hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two column matrix
&& hi.getInput(0).getParent().size()==1 //first lix is single consumer
&& hi.getInput(0).getInput(0).getDim1() == 2 ) //two column matrix
{
Hop input2 = hi.getInput().get(1); //rhs matrix
Hop pred2 = hi.getInput().get(2); //rl=ru
Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix
Hop pred1 = hi.getInput().get(0).getInput().get(2); //rl=ru
Hop input2 = hi.getInput(1); //rhs matrix
Hop pred2 = hi.getInput(2); //rl=ru
Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
Hop pred1 = hi.getInput(0).getInput(2); //rl=ru

if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
Expand All @@ -364,19 +364,19 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
{
if( hi instanceof UnaryOp && ((UnaryOp)hi).isCumulativeUnaryOperation() )
{
Hop input = hi.getInput().get(0); //input matrix
Hop input = hi.getInput(0); //input matrix

if( HopRewriteUtils.isDimsKnown(input) //dims input known
&& input.getDim1()==1 ) //1 row
{
OpOp1 op = ((UnaryOp)hi).getOp();

//remove unnecessary unary cumsum operator
//remove unnecessary unary cumsum operator
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;

LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op);
}
}
}

return hi;
Expand Down Expand Up @@ -413,27 +413,27 @@ private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos)
if( hi instanceof BinaryOp ) //binary cell operation
{
OpOp2 bop = ((BinaryOp)hi).getOp();
Hop left = hi.getInput().get(0);
Hop right = hi.getInput().get(1);
Hop left = hi.getInput(0);
Hop right = hi.getInput(1);

//check for matrix-vector column replication: (A + b %*% ones) -> (A + b)
if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
&& right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary
&& right.getInput(0).getDim2() == 1 ) //column vector for mv binary
{
//remove unnecessary outer product
HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 );
HopRewriteUtils.replaceChildReference(hi, right, right.getInput(0), 1 );
HopRewriteUtils.cleanupUnreferenced(right);

LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
}
//check for matrix-vector row replication: (A + ones %*% b) -> (A + b)
else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
&& right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0), 1)
&& right.getInput(1).getDim1() == 1 ) //row vector for mv binary
{
//remove unnecessary outer product
HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 );
HopRewriteUtils.replaceChildReference(hi, right, right.getInput(1), 1 );
HopRewriteUtils.cleanupUnreferenced(right);

LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
Expand All @@ -442,11 +442,11 @@ else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
else if(HopRewriteUtils.isValidOuterBinaryOp(bop)
&& HopRewriteUtils.isMatrixMultiply(left)
&& HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
&& (left.getInput().get(0).getDim2() == 1 //outer product
|| left.getInput().get(1).getDim1() == 1)
&& (left.getInput(0).getDim2() == 1 //outer product
|| left.getInput(1).getDim1() == 1)
&& left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector binary
{
Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
Hop hnew = HopRewriteUtils.createBinary(left.getInput(0), right, bop, true);
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
HopRewriteUtils.cleanupUnreferenced(hi);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
Expand Down Expand Up @@ -74,16 +73,7 @@ public void testMatrixMultChainOptRewritesSP() {

private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
{
ExecMode platformOld = rtplatform;
switch( et ){
case SPARK: rtplatform = ExecMode.SPARK; break;
default: rtplatform = ExecMode.HYBRID; break;
}

boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;

ExecMode platformOld = setExecMode(et);
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;

Expand Down Expand Up @@ -126,8 +116,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe
}
finally {
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
resetExecMode(platformOld);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
Expand Down Expand Up @@ -73,16 +72,7 @@ public void testMatrixMultChainOptRewritesSP() {

private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
{
ExecMode platformOld = rtplatform;
switch( et ){
case SPARK: rtplatform = ExecMode.SPARK; break;
default: rtplatform = ExecMode.HYBRID; break;
}

boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;

ExecMode platformOld = setExecMode(et);
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;

Expand Down Expand Up @@ -119,8 +109,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe
}
finally {
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
resetExecMode(platformOld);
}
}
}

0 comments on commit e705f89

Please sign in to comment.