Commit e42631e

Added a CoGrouped vector UDF support

grazy27 committed Nov 23, 2024
1 parent 5f67d45
Showing 13 changed files with 570 additions and 58 deletions.
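Before the diffs, a minimal sketch of the user-facing API this commit introduces, inferred from the E2E test below. Here `spark`, the two JSON fixtures, and `CountCharacters` (any function from two RecordBatches to one) are assumed context rather than part of the commit:

    // Sketch only; assumes the usual Microsoft.Spark.Sql and Apache.Arrow
    // usings from the tests below. CountCharacters is a placeholder for any
    // (RecordBatch, RecordBatch) => RecordBatch function.
    DataFrame left = spark.Read().Schema("age INT, name STRING").Json("people.json");
    DataFrame right = spark.Read().Schema("age INT, name STRING").Json("more_people.json");

    Row[] rows = left
        .GroupBy("age")                         // group both frames on the same key
        .CoGroup(right.GroupBy("age"))          // pair up groups with matching keys
        .Apply(
            new StructType(new[]                // declared schema of the UDF output
            {
                new StructField("age", new IntegerType()),
                new StructField("nameCharCount", new IntegerType())
            }),
            (batch1, batch2) => CountCharacters(batch1, batch2))
        .Collect()
        .ToArray();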
84 changes: 64 additions & 20 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs
@@ -337,35 +337,79 @@ public void TestGroupedMapUdf()
     }
 }
 
-private static RecordBatch ArrowBasedCountCharacters(RecordBatch records)
-{
-    StringArray nameColumn = records.Column("name") as StringArray;
-
-    int characterCount = 0;
-
-    for (int i = 0; i < nameColumn.Length; ++i)
-    {
-        string current = nameColumn.GetString(i);
-        characterCount += current.Length;
-    }
-
-    int ageFieldIndex = records.Schema.GetFieldIndex("age");
-    Field ageField = records.Schema.GetFieldByIndex(ageFieldIndex);
-
-    // Return 1 record, if we were given any. 0, otherwise.
-    int returnLength = records.Length > 0 ? 1 : 0;
-
-    return new RecordBatch(
-        new Schema.Builder()
-            .Field(ageField)
-            .Field(f => f.Name("name_CharCount").DataType(Int32Type.Default))
-            .Build(),
-        new IArrowArray[]
-        {
-            records.Column(ageFieldIndex),
-            new Int32Array.Builder().Append(characterCount).Build()
-        },
-        returnLength);
-}
+[Fact]
+public void TestCoGroupedPandasMapUdf()
+{
+    DataFrame df2 = _spark
+        .Read()
+        .Schema("age INT, name STRING")
+        .Json($"{TestEnvironment.ResourceDirectory}more_people.json");
+
+    var res = _df
+        .GroupBy("age")
+        .CoGroup(df2.GroupBy("age"))
+        .Apply(
+            new(
+            [
+                new StructField("age", new IntegerType()),
+                new StructField("nameCharCount", new IntegerType())
+            ]),
+            (rb1, rb2) => ArrowBasedCountCharacters(rb1, rb2))
+        .Collect()
+        .ToArray();
+
+    Assert.Equal(3, res.Length);
+    foreach (Row row in res)
+    {
+        int? age = row.GetAs<int?>("age");
+        int charCount = row.GetAs<int>("nameCharCount");
+
+        var expected = age switch
+        {
+            null => 14,
+            19 => 17,
+            30 => 12,
+            _ => throw new Exception($"Unexpected age: {age}.")
+        };
+
+        Assert.Equal(expected, charCount);
+    }
+}
+
+private static RecordBatch ArrowBasedCountCharacters(params RecordBatch[] records)
+{
+    List<string> names = new();
+    int? age = default;
+
+    foreach (var batch in records)
+    {
+        age ??= batch.Length > 0
+            ? (batch.Column("age") as Int32Array).GetValue(0)
+            : age;
+
+        for (var i = 0; i < batch.Length; i++)
+        {
+            names.Add((batch.Column("name") as StringArray).GetString(i));
+        }
+    }
+
+    var characterCount = names.Aggregate(0, (prev, str) => prev + str.Length);
+
+    var schema = new Schema.Builder()
+        .Field(f => f.Name("age").DataType(Int32Type.Default))
+        .Field(f => f.Name("name_CharCount").DataType(Int32Type.Default))
+        .Build();
+
+    return names.Any()
+        ? new(
+            schema,
+            [
+                new Int32Array.Builder().Append(age).Build(),
+                new Int32Array.Builder().Append(characterCount).Build()
+            ],
+            1)
+        : new(schema, [], 0);
+}
 
 [Fact]
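Two things are worth noting about the test above. The UDF receives one RecordBatch per side of the cogroup and returns at most one summary row per group. Also, the test reads the result column as "nameCharCount" (the name declared in the StructType passed to Apply) while the helper names its Arrow field "name_CharCount", which suggests the declared schema binds positionally rather than by name; that is an inference from this diff, not a documented guarantee. A minimal UDF of the same shape, with hypothetical names:

    // Two batches in (one per cogroup side), one single-row batch out.
    // SumAges is a hypothetical example, not part of this commit; it assumes
    // the same Apache.Arrow usings as the test above.
    private static RecordBatch SumAges(RecordBatch left, RecordBatch right)
    {
        long total = 0;
        foreach (RecordBatch batch in new[] { left, right })
        {
            var ages = (Int32Array)batch.Column("age");
            for (int i = 0; i < ages.Length; i++)
            {
                total += ages.GetValue(i) ?? 0;
            }
        }

        var schema = new Schema.Builder()
            .Field(f => f.Name("ageSum").DataType(Int64Type.Default))
            .Build();

        return new RecordBatch(
            schema,
            new IArrowArray[] { new Int64Array.Builder().Append(total).Build() },
            1);
    }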
133 changes: 133 additions & 0 deletions src/csharp/Microsoft.Spark.Worker.UnitTest/CommandExecutorTests.cs
@@ -1026,6 +1026,139 @@ await arrowWriter.WriteRecordBatchAsync(
    Assert.Equal(outputStream.Length, outputStream.Position);
}

[Fact]
// Only Spark 3.0+
public async Task TestArrowCoGroupedMapCommandExecutor()
{
    var ipcOptions = new IpcOptions { WriteLegacyIpcFormat = false };

    // Tags each value with its group index so the output shows which side
    // of the cogroup it came from.
    StringArray ConvertStrings(StringArray strings, int groupIndex)
    {
        return (StringArray)ToArrowArray(
            Enumerable.Range(0, strings.Length)
                .Select(i => $"group{groupIndex}: {strings.GetString(i)}")
                .ToArray());
    }

    Int64Array ConvertInt64s(Int64Array int64s, int groupIndex)
    {
        return (Int64Array)ToArrowArray(
            Enumerable.Range(0, int64s.Length)
                .Select(i => int64s.Values[i] + groupIndex * 100)
                .ToArray());
    }

    Schema resultSchema = new Schema.Builder()
        .Field(b => b.Name("group1_arg1").DataType(StringType.Default))
        .Field(b => b.Name("group1_arg2").DataType(Int64Type.Default))
        .Field(b => b.Name("group2_arg1").DataType(StringType.Default))
        .Field(b => b.Name("group2_arg2").DataType(Int64Type.Default))
        .Build();

    var udfWrapper = new Sql.ArrowCoGroupedMapUdfWrapper(
        (batch1, batch2) => new RecordBatch(
            resultSchema,
            [
                ConvertStrings((StringArray)batch1.Column(0), 1),
                ConvertInt64s((Int64Array)batch1.Column(1), 1),
                ConvertStrings((StringArray)batch2.Column(0), 2),
                ConvertInt64s((Int64Array)batch2.Column(1), 2),
            ],
            Math.Min(batch1.Length, batch2.Length)));

    var command = new SqlCommand()
    {
        ArgOffsets = new[] { 0, 1 },
        NumChainedFunctions = 1,
        WorkerFunction = new Sql.ArrowCoGroupedMapWorkerFunction(udfWrapper.Execute),
        SerializerMode = CommandSerDe.SerializedMode.Row,
        DeserializerMode = CommandSerDe.SerializedMode.Row
    };

    var commandPayload = new Worker.CommandPayload()
    {
        EvalType = UdfUtils.PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
        Commands = new[] { command }
    };

    using var inputStream = new MemoryStream();
    using var outputStream = new MemoryStream();
    int numRows = 10;

    // Write test data for both sides of the cogroup to the input stream.
    Schema schema = new Schema.Builder()
        .Field(b => b.Name("arg1").DataType(StringType.Default))
        .Field(b => b.Name("arg2").DataType(Int64Type.Default))
        .Build();
    ArrowStreamWriter arrowWriter =
        new(inputStream, schema, leaveOpen: true, ipcOptions);

    // The cogroup starts with the number of dataframes it contains (two here).
    SerDe.Write(inputStream, 2);
    await arrowWriter.WriteRecordBatchAsync(
        new RecordBatch(
            schema,
            [
                ToArrowArray(
                    Enumerable.Range(0, numRows)
                        .Select(i => $"group1_val{i}")
                        .ToArray()),
                ToArrowArray(
                    Enumerable.Range(0, numRows)
                        .Select(i => (long)i)
                        .ToArray())
            ],
            numRows));

    // Arrow end-of-stream marker closing the first dataframe's stream.
    SerDe.Write(inputStream, [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
    await arrowWriter.WriteRecordBatchAsync(
        new RecordBatch(
            schema,
            [
                ToArrowArray(
                    Enumerable.Range(0, numRows)
                        .Select(i => $"group2_val{i}")
                        .ToArray()),
                ToArrowArray(
                    Enumerable.Range(0, numRows)
                        .Select(i => (long)i)
                        .ToArray())
            ],
            numRows));

    // End-of-stream marker (continuation form) for the second dataframe,
    // then a dataframe count of zero: no more groups follow.
    SerDe.Write(inputStream, [0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00]);
    SerDe.Write(inputStream, 0);

    inputStream.Seek(0, SeekOrigin.Begin);

    CommandExecutorStat stat = new CommandExecutor(new("3.0.0.0")).Execute(
        inputStream,
        outputStream,
        0,
        commandPayload);

    // Validate that all the data on the input stream is read.
    Assert.Equal(inputStream.Length, inputStream.Position);
    Assert.Equal(numRows, stat.NumEntriesProcessed);

    // Validate the output stream.
    outputStream.Seek(0, SeekOrigin.Begin);
    int arrowLength = SerDe.ReadInt32(outputStream);
    Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
    var arrowReader = new ArrowStreamReader(outputStream);
    RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();

    Assert.Equal(1, outputBatch.ColumnCount);
    var structArray = (StructArray)outputBatch.Column(0);
    Assert.Equal(4, structArray.Fields.Count);

    for (int i = 0; i < numRows; ++i)
    {
        Assert.Equal($"group1: group1_val{i}", (structArray.Fields[0] as StringArray).GetString(i));
        Assert.Equal(i + 100, (structArray.Fields[1] as Int64Array).Values[i]);
        Assert.Equal($"group2: group2_val{i}", (structArray.Fields[2] as StringArray).GetString(i));
        Assert.Equal(i + 200, (structArray.Fields[3] as Int64Array).Values[i]);
    }

    CheckEOS(outputStream, ipcOptions);
    Assert.Equal(outputStream.Length, outputStream.Position);
}

[Theory]
[MemberData(nameof(CommandExecutorData.Data), MemberType = typeof(CommandExecutorData))]
public void TestRDDCommandExecutor(Version sparkVersion, IpcOptions ipcOptions)
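The framing the unit test writes appears to mirror the cogrouped-map protocol of the Python worker: each group is preceded by an int giving the number of dataframes it contains, each dataframe is a complete Arrow IPC stream closed by an end-of-stream marker, and a count of zero terminates the input. That reading is inferred from the byte layout above, not stated in the diff. A hypothetical reader loop over this framing; ReadArrowStream and WriteArrowStream are placeholder helpers, not the actual CommandExecutor implementation:

    while (true)
    {
        // Number of dataframes in the next cogroup; zero means no more groups.
        int dataframesInGroup = SerDe.ReadInt32(inputStream);
        if (dataframesInGroup == 0)
        {
            break;
        }

        // Each side arrives as a self-contained Arrow stream, closed by one
        // of the 8-byte end-of-stream markers written in the test above.
        RecordBatch left = ReadArrowStream(inputStream);   // placeholder helper
        RecordBatch right = ReadArrowStream(inputStream);  // placeholder helper

        WriteArrowStream(outputStream, udf(left, right));  // placeholder helper
    }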
(The remaining 11 changed files are not shown.)

0 comments on commit e42631e
