Skip to content

Commit

Permalink
[SYSTEMDS-3804] New rewrite for reverse sequences
Browse files Browse the repository at this point in the history
This patch adds a new rewrite rev(seq(1,n)) -> seq(n,1), a pattern
we recently saw in a script on vectorized time series forecasting.
  • Loading branch information
mboehm7 committed Dec 9, 2024
1 parent b5b6f37 commit 082cf89
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,9 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
hi = simplifyCTableWithConstMatrixInputs(hi); //e.g., table(X, matrix(1,...)) -> table(X, 1)
hi = removeUnnecessaryCTable(hop, hi, i); //e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and sum(table(X, Y)) -> nrow(X)
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE
hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X
Expand Down Expand Up @@ -798,6 +799,28 @@ private static Hop simplifyReverseOperation( Hop parent, Hop hi, int pos )

return hi;
}

/**
 * Rewrites the reverse of a basic ascending sequence into a descending
 * sequence, e.g., rev(seq(1,n)) -> seq(n,1,-1). The existing seq hop is
 * modified in place (from/to swapped, increment replaced by -1) and then
 * wired directly to the parent in place of the rev hop.
 *
 * The rewrite is only applied if the rev is the only consumer of the seq,
 * because the in-place modification would otherwise change the result
 * seen by the other consumers.
 *
 * @param parent parent hop of hi, whose child reference is redirected
 * @param hi current hop, candidate root of the rev(seq(1,n)) pattern
 * @param pos position of hi in the inputs of parent
 * @return the modified seq hop if the rewrite applied, otherwise hi unchanged
 */
private static Hop simplifyReverseSequence( Hop parent, Hop hi, int pos )
{
	if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
		&& HopRewriteUtils.isBasic1NSequence(hi.getInput(0))
		&& hi.getInput(0).getParent().size() == 1) //only consumer
	{
		DataGenOp seq = (DataGenOp) hi.getInput(0);
		int ixFrom = seq.getParamIndex(Statement.SEQ_FROM);
		int ixTo = seq.getParamIndex(Statement.SEQ_TO);
		int ixIncr = seq.getParamIndex(Statement.SEQ_INCR);
		Hop from = seq.getInput().get(ixFrom);
		Hop to = seq.getInput().get(ixTo);
		Hop incr = seq.getInput().get(ixIncr);

		//swap from and to in place; both hops remain inputs of seq,
		//so their existing parent references to seq stay valid
		seq.getInput().set(ixFrom, to);
		seq.getInput().set(ixTo, from);

		//replace the increment via the rewrite utils instead of a plain
		//list set, such that the old literal drops its parent reference
		//to seq and the new literal is properly linked to seq
		HopRewriteUtils.replaceChildReference(seq, incr, new LiteralOp(-1), ixIncr);

		//bypass the rev hop and remove it if it became unreferenced
		HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
		HopRewriteUtils.cleanupUnreferenced(hi, seq);
		hi = seq;
		LOG.debug("Applied simplifyReverseSequence (line "+hi.getBeginLine()+").");
	}

	return hi;
}

private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.rewrite;

import org.junit.Assert;
import org.junit.Test;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;

/**
 * Tests the algebraic simplification rewrite rev(seq(1,n)) -> seq(n,1).
 * The test script is executed once with and once without algebraic
 * simplification rewrites; both runs must produce the same sum, and with
 * rewrites enabled no rev operation may show up in the heavy hitters.
 */
public class RewriteRemoveUnnecessaryRevTest extends AutomatedTestBase
{
	private static final String TEST_NAME1 = "RewriteRemoveUnnecessaryRev";

	private static final String TEST_DIR = "functions/rewrite/";
	//class-specific directory, derived from this test class (not from
	//another test) to avoid sharing a directory with unrelated tests
	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteRemoveUnnecessaryRevTest.class.getSimpleName() + "/";

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
	}

	/** Runs the script with algebraic simplification rewrites enabled. */
	@Test
	public void testRemoveSeqRevRewrite() {
		testRewriteRemoveSeqRev( TEST_NAME1, true );
	}

	/** Runs the script with algebraic simplification rewrites disabled. */
	@Test
	public void testRemoveSeqRevNoRewrite() {
		testRewriteRemoveSeqRev( TEST_NAME1, false );
	}

	/**
	 * Executes the given test script and validates the scalar result
	 * against the closed-form sum of 1..rows.
	 *
	 * @param testname name of the test configuration and DML script
	 * @param rewrites value to set for the algebraic simplification flag
	 */
	private void testRewriteRemoveSeqRev( String testname, boolean rewrites )
	{
		//remember the global flag to restore it after the test run
		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
		int rows = 1001;

		try
		{
			TestConfiguration config = getTestConfiguration(testname);
			loadTestConfiguration(config);

			String HOME = SCRIPT_DIR + TEST_DIR;
			fullDMLScriptName = HOME + testname + ".dml";
			programArgs = new String[]{ "-stats","-args", String.valueOf(rows), output("Scalar") };
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

			runTest(true, false, null, -1);

			//compare scalars (expected value first, as required by JUnit)
			int ret = (int)readDMLScalarFromOutputDir("Scalar").get(new CellIndex(1,1)).doubleValue();
			Assert.assertEquals(rows*(rows+1)/2, ret);
			//with rewrites, the rev must have been eliminated
			if( rewrites )
				Assert.assertFalse(heavyHittersContainsString("rev"));
		}
		finally {
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
		}
	}
}
31 changes: 31 additions & 0 deletions src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Length of the generated sequence, passed as the first script argument.
rows = $1;

# Reverse of the ascending sequence 1..rows;
# to be rewritten to: seq(rows,1)
X = rev(seq(1,rows))

# Empty loop that never executes; presumably acts as a statement-block cut so
# that the sum below is compiled separately from rev(seq(...)) -- TODO confirm.
while(FALSE){}

# Sum over all sequence values, written to the path given as second argument.
R = sum(X);
write(R, $2)

0 comments on commit 082cf89

Please sign in to comment.