Skip to content

Commit

Permalink
[SYSTEMDS-3797] Fix rewrite for trace on reorg operations
Browse files Browse the repository at this point in the history
This patch fixes the rewrite for removing unnecessary reorg operations
such as sum(t(X)) or sum(rev(X)) for trace aggregations which only
consume a subset of values. Furthermore, we generalize this rewrite
to now eliminate all reorg operations that are guaranteed to preserve
all values (e.g., transpose/reshape/rev/roll, but not for diagM2V and
sort with index return).

Thanks to Jannik Lindemann for catching this issue.
  • Loading branch information
mboehm7 committed Nov 28, 2024
1 parent 9a318ee commit 9b940f7
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,10 @@ public enum ReOrgOp {
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown
RESHAPE, REV, ROLL, SORT, TRANS;

public boolean preservesValues() {
return this != DIAG && this != SORT;
}

@Override
public String toString() {
switch(this) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -980,23 +980,21 @@ private static Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos )

private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos )
{
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg
&& hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() != AggOp.TRACE //full uagg
&& hi.getInput().get(0) instanceof ReorgOp ) //reorg operation
{
ReorgOp rop = (ReorgOp)hi.getInput().get(0);
if( (rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
|| rop.getOp() == ReOrgOp.REV ) //valid reorg
&& rop.getParent().size()==1 ) //uagg only reorg consumer
if( rop.getOp().preservesValues() //valid reorg
&& rop.getParent().size()==1 ) //uagg only reorg consumer
{
Hop input = rop.getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeAllChildReferences(rop);
HopRewriteUtils.addChildReference(hi, input);

LOG.debug("Applied simplifyUnaryAggReorgOperation");
}
}

return hi;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,12 @@ private void testRewriteTraceMatrixMult(String testname, boolean rewrites) {
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");

//check trace operator existence
String uaktrace = "uaktrace";
long numTrace = Statistics.getCPHeavyHitterCount(uaktrace);

if(rewrites)
Assert.assertTrue(numTrace == 0);
else
Assert.assertTrue(numTrace == 1);

long numTrace = Statistics.getCPHeavyHitterCount("uaktrace");
Assert.assertTrue(numTrace == (rewrites ? 1 : 2));
Assert.assertTrue(heavyHittersContainsString("rev"));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))

# Perform the matrix operation
R = sum(diag(A %*% B))
rA = A;
for(i in 1:nrow(rA)) {
rA[,i] = rev(rA[,i])
}
R = R + sum(diag(rA))

# Write the result scalar R
write(R, paste(args[2], "R" ,sep=""))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ B = read($2)

# Perform the operation
R = trace(A %*% B)
R = R + trace(rev(A))

# Write the result R
write(R, $3)

0 comments on commit 9b940f7

Please sign in to comment.