Skip to content

Commit

Permalink
[SYSTEMDS-3801] Fix missing method implementations in ColGroupSDCZeros
Browse files Browse the repository at this point in the history
The previous master version broke the AWARE experiment for the kmeans+ algorithm. This patch fixes that and adds the missing method implementations for DenseBlocks in ColGroupSDCZeros.

After these changes, the runtime of the kmeans+ algorithm on the US Census dataset additionally decreased from 40s to 32s.

Closes apache#2149.
  • Loading branch information
e-strauss authored and fietenoer committed Dec 5, 2024
1 parent b6662d0 commit 9a98437
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
}
else {
while(c < points.length && points[c].o == of) {
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
c++;
}
of = it.next();
Expand All @@ -696,7 +696,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
}

while(of == last && c < points.length && points[c].o == of) {
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
c++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret,

while(of < last && c < points.length) {
if(points[c].o == of) {
c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
of = it.next();
}
else if(points[c].o < of)
Expand All @@ -848,18 +848,46 @@ else if(points[c].o < of)
while(c < points.length && points[c].o < last)
c++;

c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));

}

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
final DenseBlock dr = ret.getDenseBlock();
final int nCol = _colIndexes.size();
final AIterator it = _indexes.getIterator();
final int last = _indexes.getOffsetToLast();
int c = 0;
int of = it.value();

while(of < last && c < points.length) {
if(points[c].o == of) {
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
of = it.next();
}
else if(points[c].o < of)
c++;
else
of = it.next();
}
// increment the c pointer until it is pointing at least to last point or is done.
while(c < points.length && points[c].o < last)
c++;
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
}

/**
 * Write dictionary entry {@code did} into the sparse output once for every consecutive point, starting at position
 * {@code c}, whose offset {@code o} equals {@code of}.
 *
 * @param points The selection points; each one is read for its offset {@code o} and its output row {@code r}
 * @param sr     The sparse block handed to the dictionary's putSparse
 * @param nCol   The number of columns in this column group
 * @param c      The position in {@code points} to start from
 * @param of     The offset that a point must match to be written
 * @param did    The dictionary index to write for each matching point
 * @return The position of the first point that was not consumed (equal to {@code points.length} if all were)
 */
private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
	final int nPoints = points.length;
	for(; c < nPoints; c++) {
		// stop at the first point that refers to a different offset
		if(points[c].o != of)
			break;
		_dict.putSparse(sr, did, points[c].r, nCol, _colIndexes);
	}
	return c;
}

private int processRow(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
private int processRowDense(P[] points, final DenseBlock dr, final int nCol, int c, int of, final int did) {
while(c < points.length && points[c].o == of) {
_dict.put(sr, did, points[c].r, nCol, _colIndexes);
_dict.putDense(dr, did, points[c].r, nCol, _colIndexes);
c++;
}
return c;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.Serializable;

import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
Expand Down Expand Up @@ -87,8 +88,17 @@ public static void correctNan(double[] res, IColIndex colIndexes) {
}

@Override
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
for(int i = 0; i < nCol; i++)
sb.append(rowOut, columns.get(i), getValue(idx, i, nCol));
}

/**
 * Add the values of dictionary entry {@code idx} onto row {@code rowOut} of the given dense block.
 *
 * For each of the {@code nCol} dictionary columns, the value is accumulated (+=) into the cell at the output
 * column given by {@code columns}.
 *
 * @param dr      The dense block to write into
 * @param idx     The dictionary index to read from
 * @param rowOut  The row in the dense block to write into
 * @param nCol    The number of columns in the dictionary
 * @param columns The output column index for each dictionary column
 */
@Override
public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) {
	final double[] rowValues = dr.values(rowOut);
	final int rowStart = dr.pos(rowOut);
	int i = 0;
	while(i < nCol) {
		final int target = rowStart + columns.get(i);
		rowValues[target] += getValue(idx, i, nCol);
		i++;
	}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
Expand Down Expand Up @@ -989,6 +990,18 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef
* @param nCol The number of columns in the dictionary
* @param columns The columns to output into.
*/
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);

/**
 * Put the row specified into the dense block.
 *
 * NOTE(review): whether implementations overwrite or accumulate into existing cell values is not specified here —
 * confirm against the implementing dictionaries.
 *
 * @param db      The dense block to put into
 * @param idx     The dictionary index to put in.
 * @param rowOut  The row in the dense block to put it into
 * @param nCol    The number of columns in the dictionary
 * @param columns The columns to output into.
 */
public void putDense(DenseBlock db, int idx, int rowOut, int nCol, IColIndex columns);


}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.Serializable;

import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
Expand Down Expand Up @@ -526,7 +527,12 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex
}

@Override
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
throw new RuntimeException(errMessage);
}

/**
 * Not supported by this dictionary: every call throws a {@link RuntimeException} carrying {@code errMessage}.
 *
 * Note that the first parameter is a dense block even though it is named {@code sb}.
 */
@Override
public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
throw new RuntimeException(errMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,13 @@ public void MMDictScalingSparse() {
}

@Test(expected = Exception.class)
public void put() {
d.put(null, 1, 1, 1, null);
public void putDense() {
d.putDense(null, 1, 1, 1, null);
}

/** Calling putSparse on the dictionary under test is expected to fail; any exception satisfies the test. */
@Test(expected = Exception.class)
public void putSparse() {
// null block and null column index: the call is expected to throw
d.putSparse(null, 1, 1, 1, null);
}

@Test
Expand Down

0 comments on commit 9a98437

Please sign in to comment.