Skip to content

Commit

Permalink
Use "union" in partial aggregation output
Browse files Browse the repository at this point in the history
Since partial aggregation can be disabled at runtime,
both aggregated and the input data need to be passed
through to the final step.
This change extends the partial aggregation output
with additional columns to use for raw input
and uses those columns to send raw input in case
partial aggregation is disabled.
An additional column contains information which
set of channels should be used by the final step.
  • Loading branch information
lukasz-stec committed Apr 27, 2022
1 parent 682c0a9 commit c5b8233
Show file tree
Hide file tree
Showing 34 changed files with 804 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder;
import io.trino.operator.aggregation.builder.SpillableHashAggregationBuilder;
import io.trino.operator.aggregation.partial.PartialAggregationController;
import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor;
import io.trino.operator.aggregation.partial.SkipAggregationBuilder;
import io.trino.operator.scalar.CombineHashFunction;
import io.trino.spi.Page;
Expand Down Expand Up @@ -62,6 +63,7 @@ public static class HashAggregationOperatorFactory
private final Step step;
private final boolean produceDefaultOutput;
private final List<AggregatorFactory> aggregatorFactories;
private final PartialAggregationOutputProcessor partialAggregationOutputProcessor;
private final Optional<Integer> hashChannel;
private final Optional<Integer> groupIdChannel;

Expand All @@ -83,6 +85,8 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<? extends Type> aggregationInputTypes,
List<Integer> aggregationInputChannels,
List<Integer> globalAggregationGroupIds,
Step step,
List<AggregatorFactory> aggregatorFactories,
Expand All @@ -98,6 +102,8 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
aggregationInputTypes,
aggregationInputChannels,
globalAggregationGroupIds,
step,
false,
Expand All @@ -122,6 +128,8 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<? extends Type> aggregationInputTypes,
List<Integer> aggregationInputChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -141,6 +149,8 @@ public HashAggregationOperatorFactory(
planNodeId,
groupByTypes,
groupByChannels,
aggregationInputTypes,
aggregationInputChannels,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -164,6 +174,58 @@ public HashAggregationOperatorFactory(
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
List<? extends Type> aggregationInputTypes,
List<Integer> aggregationInputChannels,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
List<AggregatorFactory> aggregatorFactories,
Optional<Integer> hashChannel,
Optional<Integer> groupIdChannel,
int expectedGroups,
Optional<DataSize> maxPartialMemory,
boolean spillEnabled,
DataSize memoryLimitForMerge,
DataSize memoryLimitForMergeWithMemory,
SpillerFactory spillerFactory,
JoinCompiler joinCompiler,
BlockTypeOperators blockTypeOperators,
Optional<PartialAggregationController> partialAggregationController)
{
this(
operatorId,
planNodeId,
groupByTypes,
groupByChannels,
PartialAggregationOutputProcessor.create(
groupByChannels,
hashChannel,
aggregatorFactories,
aggregationInputTypes,
aggregationInputChannels),
globalAggregationGroupIds,
step,
produceDefaultOutput,
aggregatorFactories,
hashChannel,
groupIdChannel,
expectedGroups,
maxPartialMemory,
spillEnabled,
memoryLimitForMerge,
memoryLimitForMergeWithMemory,
spillerFactory,
joinCompiler,
blockTypeOperators,
partialAggregationController);
}

private HashAggregationOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
List<? extends Type> groupByTypes,
List<Integer> groupByChannels,
PartialAggregationOutputProcessor partialAggregationOutputProcessor,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -186,6 +248,7 @@ public HashAggregationOperatorFactory(
this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null");
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null");
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.step = step;
this.produceDefaultOutput = produceDefaultOutput;
Expand All @@ -211,6 +274,7 @@ public Operator createOperator(DriverContext driverContext)
operatorContext,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand Down Expand Up @@ -243,6 +307,7 @@ public OperatorFactory duplicate()
planNodeId,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
globalAggregationGroupIds,
step,
produceDefaultOutput,
Expand All @@ -265,6 +330,7 @@ public OperatorFactory duplicate()
private final Optional<PartialAggregationController> partialAggregationController;
private final List<Type> groupByTypes;
private final List<Integer> groupByChannels;
private final PartialAggregationOutputProcessor partialAggregationOutputProcessor;
private final List<Integer> globalAggregationGroupIds;
private final Step step;
private final boolean produceDefaultOutput;
Expand Down Expand Up @@ -299,6 +365,7 @@ private HashAggregationOperator(
OperatorContext operatorContext,
List<Type> groupByTypes,
List<Integer> groupByChannels,
PartialAggregationOutputProcessor partialAggregationOutputProcessor,
List<Integer> globalAggregationGroupIds,
Step step,
boolean produceDefaultOutput,
Expand All @@ -324,6 +391,7 @@ private HashAggregationOperator(

this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null");
this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds);
this.aggregatorFactories = ImmutableList.copyOf(aggregatorFactories);
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
Expand Down Expand Up @@ -390,7 +458,7 @@ public void addInput(Page page)
.map(PartialAggregationController::isPartialAggregationDisabled)
.orElse(false);
if (step.isOutputPartial() && partialAggregationDisabled) {
aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext);
aggregationBuilder = new SkipAggregationBuilder(partialAggregationOutputProcessor, memoryContext);
}
else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
// TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling.
Expand All @@ -400,6 +468,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
expectedGroups,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
hashChannel,
operatorContext,
maxPartialMemory,
Expand All @@ -421,6 +490,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
expectedGroups,
groupByTypes,
groupByChannels,
partialAggregationOutputProcessor,
hashChannel,
operatorContext,
memoryLimitForMerge,
Expand Down Expand Up @@ -584,7 +654,13 @@ private Page getGlobalAggregationOutput()
if (output.isEmpty()) {
return null;
}
return output.build();

Page page = output.build();
if (step.isOutputPartial()) {
page = partialAggregationOutputProcessor.processAggregatedPage(page);
}

return page;
}

private static long calculateDefaultOutputHash(List<Type> groupByChannels, int groupIdChannel, int groupId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class AggregatorFactory
private final Type intermediateType;
private final Type finalType;
private final List<Integer> inputChannels;
private final int intermediateStateChannel;
private final OptionalInt rawInputMaskChannel;
private final OptionalInt maskChannel;
private final boolean spillable;
private final List<Supplier<Object>> lambdaProviders;
Expand All @@ -41,6 +43,7 @@ public AggregatorFactory(
Type intermediateType,
Type finalType,
List<Integer> inputChannels,
OptionalInt rawInputMaskChannel,
OptionalInt maskChannel,
boolean spillable,
List<Supplier<Object>> lambdaProviders)
Expand All @@ -49,12 +52,21 @@ public AggregatorFactory(
this.step = requireNonNull(step, "step is null");
this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
this.finalType = requireNonNull(finalType, "finalType is null");
this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null"));
requireNonNull(inputChannels, "inputChannels is null");
if (step.isInputRaw()) {
intermediateStateChannel = -1;
this.inputChannels = ImmutableList.copyOf(inputChannels);
}
else {
intermediateStateChannel = inputChannels.get(0);
this.inputChannels = ImmutableList.copyOf(inputChannels.subList(1, inputChannels.size()));
}
this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null");
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
this.spillable = spillable;
this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null"));

checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
checkArgument(step.isInputRaw() || intermediateStateChannel != -1, "expected intermediateStateChannel for intermediate aggregation but got %s ", intermediateStateChannel);
}

public Aggregator createAggregator()
Expand All @@ -66,7 +78,8 @@ public Aggregator createAggregator()
else {
accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders);
}
return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
List<Integer> aggregatorInputChannels = intermediateStateChannel == -1 ? inputChannels : ImmutableList.of(intermediateStateChannel);
return new Aggregator(accumulator, step, intermediateType, finalType, aggregatorInputChannels, maskChannel);
}

public GroupedAggregator createGroupedAggregator()
Expand All @@ -78,7 +91,7 @@ public GroupedAggregator createGroupedAggregator()
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel);
return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, intermediateStateChannel, rawInputMaskChannel, maskChannel);
}

public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel)
Expand All @@ -90,11 +103,16 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan
else {
accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders);
}
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel);
return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), inputChannel, OptionalInt.empty(), maskChannel);
}

public boolean isSpillable()
{
return spillable;
}

public OptionalInt getMaskChannel()
{
return maskChannel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Step;
import org.apache.commons.math3.util.Pair;

import java.util.List;
import java.util.Optional;
Expand All @@ -36,17 +39,21 @@ public class GroupedAggregator
private final Type intermediateType;
private final Type finalType;
private final int[] inputChannels;
private final int intermediateStateChannel;
private final OptionalInt rawInputMaskChannel;
private final OptionalInt maskChannel;

public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type intermediateType, Type finalType, List<Integer> inputChannels, OptionalInt maskChannel)
public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type intermediateType, Type finalType, List<Integer> inputChannels, int intermediateStateChannel, OptionalInt rawInputMaskChannel, OptionalInt maskChannel)
{
this.accumulator = requireNonNull(accumulator, "accumulator is null");
this.step = requireNonNull(step, "step is null");
this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
this.finalType = requireNonNull(finalType, "finalType is null");
this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
this.intermediateStateChannel = intermediateStateChannel;
this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null");
this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
checkArgument(step.isInputRaw() || intermediateStateChannel != -1, "expected intermediateStateChannel for intermediate aggregation but got %s ", intermediateStateChannel);
}

public long getEstimatedSize()
Expand All @@ -68,10 +75,42 @@ public void processPage(GroupByIdBlock groupIds, Page page)
{
if (step.isInputRaw()) {
accumulator.addInput(groupIds, page.getColumns(inputChannels), getMaskBlock(page));
return;
}
else {
accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0]));

if (rawInputMaskChannel.isEmpty()) {
// process partially aggregated data
accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel));
return;
}
Block rawInputMaskBlock = page.getBlock(rawInputMaskChannel.getAsInt());
if (rawInputMaskBlock instanceof RunLengthEncodedBlock) {
if (rawInputMaskBlock.isNull(0)) {
// process partially aggregated data
accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel));
}
else {
// process raw data
accumulator.addInput(groupIds, page.getColumns(inputChannels), getMaskBlock(page));
}
return;
}

// rawInputMaskBlock has potentially mixed partially aggregated and raw data
Optional<Block> maskBlock = Optional.of((getMaskBlock(page).map(mask -> andMasks(mask, rawInputMaskBlock)).orElse(rawInputMaskBlock)));
accumulator.addInput(groupIds, page.getColumns(inputChannels), maskBlock);
Pair<Block, GroupByIdBlock> filtered = filterByNull(page.getBlock(intermediateStateChannel), groupIds, rawInputMaskBlock);
accumulator.addIntermediate(filtered.getSecond(), filtered.getFirst());
}

private Block andMasks(Block mask1, Block mask2)
{
int positionCount = mask1.getPositionCount();
byte[] mask = new byte[positionCount];
for (int i = 0; i < positionCount; i++) {
mask[i] = (byte) ((!mask1.isNull(i) && mask1.getByte(i, 0) == 1 && !mask2.isNull(i) && mask2.getByte(i, 0) == 1) ? 1 : 0);
}
return new ByteArrayBlock(positionCount, Optional.empty(), mask);
}

private Optional<Block> getMaskBlock(Page page)
Expand Down Expand Up @@ -107,4 +146,24 @@ public Type getSpillType()
{
return intermediateType;
}

private static Pair<Block, GroupByIdBlock> filterByNull(Block block, GroupByIdBlock groupByIdBlock, Block mask)
{
int positions = mask.getPositionCount();

int[] ids = new int[positions];
int next = 0;
for (int i = 0; i < ids.length; ++i) {
if (mask.isNull(i)) {
ids[next++] = i;
}
}

if (next == ids.length) {
return Pair.create(block, groupByIdBlock); // no rows were eliminated by the filter
}
return Pair.create(
block.getPositions(ids, 0, next),
new GroupByIdBlock(groupByIdBlock.getGroupCount(), groupByIdBlock.getPositions(ids, 0, next)));
}
}
Loading

0 comments on commit c5b8233

Please sign in to comment.