diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index c9a97450911..15207e87b5b 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -243,7 +243,7 @@ private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos) { if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && !hi.isScalar() ) { //remove unnecessary right indexing - Hop input = hi.getInput().get(0); + Hop input = hi.getInput(0); HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; @@ -258,8 +258,8 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) { if( hi instanceof LeftIndexingOp && hi.getDataType() == DataType.MATRIX ) //left indexing op { - Hop input1 = hi.getInput().get(0); //lhs matrix - Hop input2 = hi.getInput().get(1); //rhs matrix + Hop input1 = hi.getInput(0); //lhs matrix + Hop input2 = hi.getInput(1); //rhs matrix if( input1.getNnz()==0 //nnz original known and empty && input2.getNnz()==0 ) //nnz input known and empty @@ -271,7 +271,7 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) hi = hnew; LOG.debug("Applied removeEmptyLeftIndexing"); - } + } } return hi; @@ -281,19 +281,19 @@ private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos) { if( hi instanceof LeftIndexingOp ) //left indexing op { - Hop input = hi.getInput().get(1); //rhs matrix/frame + Hop input = hi.getInput(1); //rhs matrix/frame if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims { //equal dims of left indexing input and output -> no need for indexing - //remove unnecessary right indexing + //remove unnecessary right indexing HopRewriteUtils.replaceChildReference(parent, hi, input, pos); HopRewriteUtils.cleanupUnreferenced(hi); hi = input; LOG.debug("Applied removeUnnecessaryLeftIndexing"); - } + } } return hi; @@ -306,15 +306,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) //pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame if( hi instanceof LeftIndexingOp //first lix && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi) - && hi.getInput().get(0) instanceof LeftIndexingOp //second lix + && hi.getInput(0) instanceof LeftIndexingOp //second lix && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0)) - && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer - && hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two column matrix + && hi.getInput(0).getParent().size()==1 //first lix is single consumer + && hi.getInput(0).getInput(0).getDim2() == 2 ) //two column matrix { - Hop input2 = hi.getInput().get(1); //rhs matrix - Hop pred2 = hi.getInput().get(4); //cl=cu - Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix - Hop pred1 = hi.getInput().get(0).getInput().get(4); //cl=cu + Hop input2 = hi.getInput(1); //rhs matrix + Hop pred2 = hi.getInput(4); //cl=cu + Hop input1 = hi.getInput(0).getInput(1); //lhs matrix + Hop pred1 = hi.getInput(0).getInput(4); //cl=cu if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 @@ -332,15 +332,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) //pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B) if( !applied && hi instanceof LeftIndexingOp //first lix && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi) - && hi.getInput().get(0) instanceof LeftIndexingOp //second lix + && hi.getInput(0) instanceof LeftIndexingOp //second lix && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0)) - && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer - && hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two column matrix + && hi.getInput(0).getParent().size()==1 //first lix is single consumer + && hi.getInput(0).getInput(0).getDim1() == 2 ) //two column matrix { - Hop input2 = hi.getInput().get(1); //rhs matrix - Hop pred2 = hi.getInput().get(2); //rl=ru - Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix - Hop pred1 = hi.getInput().get(0).getInput().get(2); //rl=ru + Hop input2 = hi.getInput(1); //rhs matrix + Hop pred2 = hi.getInput(2); //rl=ru + Hop input1 = hi.getInput(0).getInput(1); //lhs matrix + Hop pred1 = hi.getInput(0).getInput(2); //rl=ru if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 @@ -364,19 +364,19 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) { if( hi instanceof UnaryOp && ((UnaryOp)hi).isCumulativeUnaryOperation() ) { - Hop input = hi.getInput().get(0); //input matrix + Hop input = hi.getInput(0); //input matrix if( HopRewriteUtils.isDimsKnown(input) //dims input known && input.getDim1()==1 ) //1 row { OpOp1 op = ((UnaryOp)hi).getOp(); - //remove unnecessary unary cumsum operator + //remove unnecessary unary cumsum operator HopRewriteUtils.replaceChildReference(parent, hi, input, pos); hi = input; LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op); - } + } } return hi; @@ -413,27 +413,27 @@ private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos) if( hi instanceof BinaryOp ) //binary cell operation { OpOp2 bop = ((BinaryOp)hi).getOp(); - Hop left = hi.getInput().get(0); - Hop right = hi.getInput().get(1); + Hop left = hi.getInput(0); + Hop right = hi.getInput(1); //check for matrix-vector column replication: (A + b %*% ones) -> (A + b) if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1) - && right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary + && right.getInput(0).getDim2() == 1 ) //column vector for mv binary { //remove unnecessary outer product - HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 ); + HopRewriteUtils.replaceChildReference(hi, right, right.getInput(0), 1 ); HopRewriteUtils.cleanupUnreferenced(right); LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")"); } //check for matrix-vector row replication: (A + ones %*% b) -> (A + b) else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen - && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1) - && right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary + && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0), 1) + && right.getInput(1).getDim1() == 1 ) //row vector for mv binary { //remove unnecessary outer product - HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 ); + HopRewriteUtils.replaceChildReference(hi, right, right.getInput(1), 1 ); HopRewriteUtils.cleanupUnreferenced(right); LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")"); @@ -442,11 +442,11 @@ else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen else if(HopRewriteUtils.isValidOuterBinaryOp(bop) && HopRewriteUtils.isMatrixMultiply(left) && HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1) - && (left.getInput().get(0).getDim2() == 1 //outer product - || left.getInput().get(1).getDim1() == 1) + && (left.getInput(0).getDim2() == 1 //outer product + || left.getInput(1).getDim1() == 1) && left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector binary { - Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true); + Hop hnew = HopRewriteUtils.createBinary(left.getInput(0), right, bop, true); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); HopRewriteUtils.cleanupUnreferenced(hi); diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java index 78728d9a71f..15b24534c1e 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java @@ -23,7 +23,6 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.common.Types.ExecType; @@ -74,16 +73,7 @@ public void testMatrixMultChainOptRewritesSP() { private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) { - ExecMode platformOld = rtplatform; - switch( et ){ - case SPARK: rtplatform = ExecMode.SPARK; break; - default: rtplatform = ExecMode.HYBRID; break; - } - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - + ExecMode platformOld = setExecMode(et); boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; @@ -126,8 +116,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe } finally { OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + resetExecMode(platformOld); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java index d60df3f665c..6c6ede61d70 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java @@ -23,7 +23,6 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.common.Types.ExecType; @@ -73,16 +72,7 @@ public void testMatrixMultChainOptRewritesSP() { private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) { - ExecMode platformOld = rtplatform; - switch( et ){ - case SPARK: rtplatform = ExecMode.SPARK; break; - default: rtplatform = ExecMode.HYBRID; break; - } - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - + ExecMode platformOld = setExecMode(et); boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; @@ -119,8 +109,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe } finally { OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + resetExecMode(platformOld); } } }