Skip to content

Commit

Permalink
Use "union" in partial aggregation output
Browse files Browse the repository at this point in the history
Since partial aggregation can be disabled at runtime,
both aggregated and the input data need to be passed
through to the final step.
This change extends the partial aggregation output
with additional columns to use for raw input
and uses those columns to send raw input in case
partial aggregation is disabled.
An additional column contains information which
set of channels should be used by the final step.
  • Loading branch information
lukasz-stec committed May 31, 2022
1 parent 629cf5e commit f92ea04
Show file tree
Hide file tree
Showing 35 changed files with 804 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder;
import io.trino.operator.aggregation.builder.SpillableHashAggregationBuilder;
import io.trino.operator.aggregation.partial.PartialAggregationController;
import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor;
import io.trino.operator.aggregation.partial.SkipAggregationBuilder;
import io.trino.operator.scalar.CombineHashFunction;
import io.trino.spi.Page;
Expand Down Expand Up @@ -62,6 +63,7 @@ public static class HashAggregationOperatorFactory
private final Step step;
private final boolean produceDefaultOutput;
private final List<AggregatorFactory> aggregatorFactories;
private final Optional<PartialAggregationOutputProcessor> partialAggregationOutputProcessor;
private final Optional<Integer> hashChannel;
private final Optional<Integer> groupIdChannel;

Expand All @@ -83,6 +85,8 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<? extends Type> aggregationRawInputTypes,
List<Integer> aggregationInputChannels,
List<Integer> globalAggregationGroupIds,
Step step,
List<AggregatorFactory> aggregatorFactories,
Expand All @@ -98,6 +102,14 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
step.isOutputPartial() ?
Optional.of(new PartialAggregationOutputProcessor(
groupByChannels,
hashChannel,
aggregatorFactories,
aggregationRawInputTypes,
aggregationInputChannels)) :
Optional.empty(),
globalAggregationGroupIds,
step,
false,
Expand All @@ -122,6 +134,8 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<? extends Type> aggregationRawInputTypes,
List<Integer> aggregationInputChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -141,6 +155,14 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
step.isOutputPartial() ?
Optional.of(new PartialAggregationOutputProcessor(
groupByChannels,
hashChannel,
aggregatorFactories,
aggregationRawInputTypes,
aggregationInputChannels)) :
Optional.empty(),
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -164,6 +186,7 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
Optional<PartialAggregationOutputProcessor> partialAggregationOutputProcessor,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -186,6 +209,10 @@ public HashAggregationOperatorFactory(
this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null");
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
if (step.isOutputPartial()) {
checkArgument(partialAggregationOutputProcessor.isPresent(), "partialAggregationOutputProcessor must be present in the partial step");
}
this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null");
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.step = step;
this.produceDefaultOutput = produceDefaultOutput;
Expand All @@ -211,6 +238,7 @@ public Operator createOperator(DriverContext driverContext)
operatorContext,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand Down Expand Up @@ -243,6 +271,7 @@ public OperatorFactory duplicate()
planNodeId,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -265,6 +294,7 @@ public OperatorFactory duplicate()
private final Optional<PartialAggregationController> partialAggregationController;
private final List<Type> groupByTypes;
private final List<Integer> groupByChannels;
private final Optional<PartialAggregationOutputProcessor> partialAggregationOutputProcessor;
private final List<Integer> globalAggregationGroupIds;
private final Step step;
private final boolean produceDefaultOutput;
Expand Down Expand Up @@ -299,6 +329,7 @@ private HashAggregationOperator(
OperatorContext operatorContext,
List<Type> groupByTypes,
List<Integer> groupByChannels,
Optional<PartialAggregationOutputProcessor> partialAggregationOutputProcessor,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -321,9 +352,12 @@ private HashAggregationOperator(
requireNonNull(aggregatorFactories, "aggregatorFactories is null");
requireNonNull(operatorContext, "operatorContext is null");
checkArgument(partialAggregationController.isEmpty() || step.isOutputPartial(), "partialAggregationController should be present only for partial aggregation");

if (step.isOutputPartial()) {
checkArgument(partialAggregationOutputProcessor.isPresent(), "partialAggregationOutputProcessor must be present in the partial step");
}
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null");
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.aggregatorFactories = ImmutableList.copyOf(aggregatorFactories);
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
Expand Down Expand Up @@ -390,7 +424,7 @@ public void addInput(Page page)
.map(PartialAggregationController::isPartialAggregationDisabled)
.orElse(false);
if (step.isOutputPartial() && partialAggregationDisabled) {
aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext);
aggregationBuilder = new SkipAggregationBuilder(partialAggregationOutputProcessor.get(), memoryContext);
}
else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
// TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling.
Expand All @@ -400,6 +434,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
expectedGroups,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
hashChannel,
operatorContext,
maxPartialMemory,
Expand All @@ -421,6 +456,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
expectedGroups,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
hashChannel,
operatorContext,
memoryLimitForMerge,
Expand Down Expand Up @@ -584,7 +620,13 @@ private Page getGlobalAggregationOutput()
if (output.isEmpty()) {
return null;
}
return output.build();

Page page = output.build();
if (step.isOutputPartial()) {
page = partialAggregationOutputProcessor.get().processAggregatedPage(page);
}

return page;
}

private static long calculateDefaultOutputHash(List<Type> groupByChannels, int groupIdChannel, int groupId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.OptionalInt;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class AggregatorFactory
Expand All @@ -30,7 +29,9 @@ public class AggregatorFactory
private final Step step;
private final Type intermediateType;
private final Type finalType;
private final List<Integer> inputChannels;
private final List<Integer> rawInputChannels;
private final OptionalInt intermediateStateChannel;
private final OptionalInt rawInputMaskChannel;
private final OptionalInt maskChannel;
private final boolean spillable;
private final List<Supplier<Object>> lambdaProviders;
Expand All @@ -41,6 +42,7 @@ public AggregatorFactory(
Type intermediateType,
Type finalType,
List<Integer> inputChannels,
OptionalInt rawInputMaskChannel,
OptionalInt maskChannel,
boolean spillable,
List<Supplier<Object>> lambdaProviders)
Expand All @@ -49,12 +51,19 @@ public AggregatorFactory(
this.step = requireNonNull(step, "step is null");
this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
this.finalType = requireNonNull(finalType, "finalType is null");
this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null"));
requireNonNull(inputChannels, "inputChannels is null");
if (step.isInputRaw()) {
intermediateStateChannel = OptionalInt.empty();
this.rawInputChannels = ImmutableList.copyOf(inputChannels);
}
else {
intermediateStateChannel = OptionalInt.of(inputChannels.get(0));
this.rawInputChannels = ImmutableList.copyOf(inputChannels.subList(1, inputChannels.size()));
}
this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null");
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
this.spillable = spillable;
this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null"));

checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
}

public Aggregator createAggregator()
Expand All @@ -66,6 +75,7 @@ public Aggregator createAggregator()
else {
accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders);
}
List<Integer> inputChannels = intermediateStateChannel.isEmpty() ? rawInputChannels : ImmutableList.of(intermediateStateChannel.getAsInt());
return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
}

Expand All @@ -78,7 +88,7 @@ public GroupedAggregator createGroupedAggregator()
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
return new GroupedAggregator(accumulator, step, intermediateType, finalType, rawInputChannels, intermediateStateChannel, rawInputMaskChannel, maskChannel);
}

public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel)
Expand All @@ -90,11 +100,21 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel);
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), OptionalInt.of(inputChannel), OptionalInt.empty(), maskChannel);
}

public boolean isSpillable()
{
return spillable;
}

public OptionalInt getMaskChannel()
{
return maskChannel;
}

public Type getIntermediateType()
{
return intermediateType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Step;
import it.unimi.dsi.fastutil.ints.IntArrayList;

import java.util.List;
import java.util.Optional;
Expand All @@ -35,18 +38,30 @@ public class GroupedAggregator
private AggregationNode.Step step;
private final Type intermediateType;
private final Type finalType;
private final int[] inputChannels;
private final int[] rawInputChannels;
private final OptionalInt intermediateStateChannel;
private final OptionalInt rawInputMaskChannel;
private final OptionalInt maskChannel;

public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type intermediateType, Type finalType, List<Integer> inputChannels, OptionalInt maskChannel)
public GroupedAggregator(
GroupedAccumulator accumulator,
Step step,
Type intermediateType,
Type finalType,
List<Integer> rawInputChannels,
OptionalInt intermediateStateChannel,
OptionalInt rawInputMaskChannel,
OptionalInt maskChannel)
{
this.accumulator = requireNonNull(accumulator, "accumulator is null");
this.step = requireNonNull(step, "step is null");
this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
this.finalType = requireNonNull(finalType, "finalType is null");
this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
this.rawInputChannels = Ints.toArray(requireNonNull(rawInputChannels, "inputChannels is null"));
this.intermediateStateChannel = requireNonNull(intermediateStateChannel, "intermediateStateChannel is null");
this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null");
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
checkArgument(step.isInputRaw() || intermediateStateChannel.isPresent(), "expected intermediateStateChannel for intermediate aggregation but got %s ", intermediateStateChannel);
}

public long getEstimatedSize()
Expand All @@ -67,11 +82,55 @@ public Type getType()
public void processPage(GroupByIdBlock groupIds, Page page)
{
if (step.isInputRaw()) {
accumulator.addInput(groupIds, page.getColumns(inputChannels), getMaskBlock(page));
accumulator.addInput(groupIds, page.getColumns(rawInputChannels), getMaskBlock(page));
return;
}
else {
accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0]));

if (rawInputMaskChannel.isEmpty()) {
// process partially aggregated data
accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel.getAsInt()));
return;
}
Block rawInputMaskBlock = page.getBlock(rawInputMaskChannel.getAsInt());
if (rawInputMaskBlock instanceof RunLengthEncodedBlock) {
if (rawInputMaskBlock.isNull(0)) {
// process partially aggregated data
accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel.getAsInt()));
}
else {
// process raw data
accumulator.addInput(groupIds, page.getColumns(rawInputChannels), getMaskBlock(page));
}
return;
}

// rawInputMaskBlock has potentially mixed partially aggregated and raw data
Block maskBlock = getMaskBlock(page)
.map(mask -> andMasks(mask, rawInputMaskBlock))
.orElse(rawInputMaskBlock);
accumulator.addInput(groupIds, page.getColumns(rawInputChannels), Optional.of(maskBlock));
IntArrayList intermediatePositions = filterByNull(rawInputMaskBlock);
Block intermediateStateBlock = page.getBlock(intermediateStateChannel.getAsInt());

if (intermediatePositions.size() != rawInputMaskBlock.getPositionCount()) {
// some rows were eliminated by the filter
intermediateStateBlock = intermediateStateBlock.getPositions(intermediatePositions.elements(), 0, intermediatePositions.size());
groupIds = new GroupByIdBlock(
groupIds.getGroupCount(),
groupIds.getPositions(intermediatePositions.elements(), 0, intermediatePositions.size()));
}

accumulator.addIntermediate(groupIds, intermediateStateBlock);
}

private Block andMasks(Block mask1, Block mask2)
{
int positionCount = mask1.getPositionCount();
byte[] mask = new byte[positionCount];
for (int i = 0; i < positionCount; i++) {
mask[i] = (byte) ((!mask1.isNull(i) && mask1.getByte(i, 0) == 1 && !mask2.isNull(i) && mask2.getByte(i, 0) == 1) ? 1 : 0);
}
return new ByteArrayBlock(positionCount, Optional.empty(), mask);
}

private Optional<Block> getMaskBlock(Page page)
Expand Down Expand Up @@ -107,4 +166,18 @@ public Type getSpillType()
{
return intermediateType;
}

private static IntArrayList filterByNull(Block mask)
{
int positions = mask.getPositionCount();

IntArrayList ids = new IntArrayList(positions);
for (int i = 0; i < positions; ++i) {
if (mask.isNull(i)) {
ids.add(i);
}
}

return ids;
}
}
Loading

0 comments on commit f92ea04

Please sign in to comment.