diff --git a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java index a500a160e102..b7b9862a39f1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/HashAggregationOperator.java @@ -23,6 +23,7 @@ import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder; import io.trino.operator.aggregation.builder.SpillableHashAggregationBuilder; import io.trino.operator.aggregation.partial.PartialAggregationController; +import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor; import io.trino.operator.aggregation.partial.SkipAggregationBuilder; import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.Page; @@ -62,6 +63,7 @@ public static class HashAggregationOperatorFactory private final Step step; private final boolean produceDefaultOutput; private final List aggregatorFactories; + private final Optional partialAggregationOutputProcessor; private final Optional hashChannel; private final Optional groupIdChannel; @@ -83,6 +85,8 @@ public HashAggregationOperatorFactory( PlanNodeId planNodeId, List groupByTypes, List groupByChannels, + List aggregationRawInputTypes, + List aggregationInputChannels, List globalAggregationGroupIds, Step step, List aggregatorFactories, @@ -98,6 +102,14 @@ public HashAggregationOperatorFactory( planNodeId, groupByTypes, groupByChannels, + step.isOutputPartial() ? + Optional.of(new PartialAggregationOutputProcessor( + groupByChannels, + hashChannel, + aggregatorFactories, + aggregationRawInputTypes, + aggregationInputChannels)) : + Optional.empty(), globalAggregationGroupIds, step, false, @@ -122,6 +134,8 @@ public HashAggregationOperatorFactory( PlanNodeId planNodeId, List groupByTypes, List groupByChannels, + List aggregationRawInputTypes, + List aggregationInputChannels, List globalAggregationGroupIds, Step step, boolean produceDefaultOutput, @@ -141,6 +155,14 @@ public HashAggregationOperatorFactory( planNodeId, groupByTypes, groupByChannels, + step.isOutputPartial() ? + Optional.of(new PartialAggregationOutputProcessor( + groupByChannels, + hashChannel, + aggregatorFactories, + aggregationRawInputTypes, + aggregationInputChannels)) : + Optional.empty(), globalAggregationGroupIds, step, produceDefaultOutput, @@ -164,6 +186,7 @@ public HashAggregationOperatorFactory( PlanNodeId planNodeId, List groupByTypes, List groupByChannels, + Optional partialAggregationOutputProcessor, List globalAggregationGroupIds, Step step, boolean produceDefaultOutput, @@ -186,6 +209,10 @@ public HashAggregationOperatorFactory( this.groupIdChannel = requireNonNull(groupIdChannel, "groupIdChannel is null"); this.groupByTypes = ImmutableList.copyOf(groupByTypes); this.groupByChannels = ImmutableList.copyOf(groupByChannels); + if (step.isOutputPartial()) { + checkArgument(partialAggregationOutputProcessor.isPresent(), "partialAggregationOutputProcessor must be present in the partial step"); + } + this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null"); this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds); this.step = step; this.produceDefaultOutput = produceDefaultOutput; @@ -211,6 +238,7 @@ public Operator createOperator(DriverContext driverContext) operatorContext, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, globalAggregationGroupIds, step, produceDefaultOutput, @@ -243,6 +271,7 @@ public OperatorFactory duplicate() planNodeId, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, globalAggregationGroupIds, step, produceDefaultOutput, @@ -265,6 +294,7 @@ public OperatorFactory duplicate() private final Optional partialAggregationController; private final List groupByTypes; private final List groupByChannels; + private final Optional partialAggregationOutputProcessor; private final List globalAggregationGroupIds; private final Step step; private final boolean produceDefaultOutput; @@ -299,6 +329,7 @@ private HashAggregationOperator( OperatorContext operatorContext, List groupByTypes, List groupByChannels, + Optional partialAggregationOutputProcessor, List globalAggregationGroupIds, Step step, boolean produceDefaultOutput, @@ -321,9 +352,12 @@ private HashAggregationOperator( requireNonNull(aggregatorFactories, "aggregatorFactories is null"); requireNonNull(operatorContext, "operatorContext is null"); checkArgument(partialAggregationController.isEmpty() || step.isOutputPartial(), "partialAggregationController should be present only for partial aggregation"); - + if (step.isOutputPartial()) { + checkArgument(partialAggregationOutputProcessor.isPresent(), "partialAggregationOutputProcessor must be present in the partial step"); + } this.groupByTypes = ImmutableList.copyOf(groupByTypes); this.groupByChannels = ImmutableList.copyOf(groupByChannels); + this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null"); this.globalAggregationGroupIds = ImmutableList.copyOf(globalAggregationGroupIds); this.aggregatorFactories = ImmutableList.copyOf(aggregatorFactories); this.hashChannel = requireNonNull(hashChannel, "hashChannel is null"); @@ -390,7 +424,7 @@ public void addInput(Page page) .map(PartialAggregationController::isPartialAggregationDisabled) .orElse(false); if (step.isOutputPartial() && partialAggregationDisabled) { - aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext); + aggregationBuilder = new SkipAggregationBuilder(partialAggregationOutputProcessor.get(), memoryContext); } else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. @@ -400,6 +434,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { expectedGroups, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, hashChannel, operatorContext, maxPartialMemory, @@ -421,6 +456,7 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) { expectedGroups, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, hashChannel, operatorContext, memoryLimitForMerge, @@ -584,7 +620,13 @@ private Page getGlobalAggregationOutput() if (output.isEmpty()) { return null; } - return output.build(); + + Page page = output.build(); + if (step.isOutputPartial()) { + page = partialAggregationOutputProcessor.get().processAggregatedPage(page); + } + + return page; } private static long calculateDefaultOutputHash(List groupByChannels, int groupIdChannel, int groupId) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java index 968162f39347..ac164b70116c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregatorFactory.java @@ -21,7 +21,6 @@ import java.util.OptionalInt; import java.util.function.Supplier; -import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public class AggregatorFactory @@ -30,7 +29,9 @@ public class AggregatorFactory private final Step step; private final Type intermediateType; private final Type finalType; - private final List inputChannels; + private final List rawInputChannels; + private final OptionalInt intermediateStateChannel; + private final OptionalInt rawInputMaskChannel; private final OptionalInt maskChannel; private final boolean spillable; private final List> lambdaProviders; @@ -41,6 +42,7 @@ public AggregatorFactory( Type intermediateType, Type finalType, List inputChannels, + OptionalInt rawInputMaskChannel, OptionalInt maskChannel, boolean spillable, List> lambdaProviders) @@ -49,12 +51,19 @@ public AggregatorFactory( this.step = requireNonNull(step, "step is null"); this.intermediateType = requireNonNull(intermediateType, "intermediateType is null"); this.finalType = requireNonNull(finalType, "finalType is null"); - this.inputChannels = ImmutableList.copyOf(requireNonNull(inputChannels, "inputChannels is null")); + requireNonNull(inputChannels, "inputChannels is null"); + if (step.isInputRaw()) { + intermediateStateChannel = OptionalInt.empty(); + this.rawInputChannels = ImmutableList.copyOf(inputChannels); + } + else { + intermediateStateChannel = OptionalInt.of(inputChannels.get(0)); + this.rawInputChannels = ImmutableList.copyOf(inputChannels.subList(1, inputChannels.size())); + } + this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null"); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); this.spillable = spillable; this.lambdaProviders = ImmutableList.copyOf(requireNonNull(lambdaProviders, "lambdaProviders is null")); - - checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); } public Aggregator createAggregator() @@ -66,6 +75,7 @@ public Aggregator createAggregator() else { accumulator = accumulatorFactory.createIntermediateAccumulator(lambdaProviders); } + List inputChannels = intermediateStateChannel.isEmpty() ? rawInputChannels : ImmutableList.of(intermediateStateChannel.getAsInt()); return new Aggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); } @@ -78,7 +88,7 @@ public GroupedAggregator createGroupedAggregator() else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, inputChannels, maskChannel); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, rawInputChannels, intermediateStateChannel, rawInputMaskChannel, maskChannel); } public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChannel) @@ -90,11 +100,21 @@ public GroupedAggregator createUnspillGroupedAggregator(Step step, int inputChan else { accumulator = accumulatorFactory.createGroupedIntermediateAccumulator(lambdaProviders); } - return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), maskChannel); + return new GroupedAggregator(accumulator, step, intermediateType, finalType, ImmutableList.of(inputChannel), OptionalInt.of(inputChannel), OptionalInt.empty(), maskChannel); } public boolean isSpillable() { return spillable; } + + public OptionalInt getMaskChannel() + { + return maskChannel; + } + + public Type getIntermediateType() + { + return intermediateType; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java index 05f43e2b38b3..7fb7ced5c1dc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/GroupedAggregator.java @@ -18,9 +18,12 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Step; +import it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.List; import java.util.Optional; @@ -35,18 +38,30 @@ public class GroupedAggregator private AggregationNode.Step step; private final Type intermediateType; private final Type finalType; - private final int[] inputChannels; + private final int[] rawInputChannels; + private final OptionalInt intermediateStateChannel; + private final OptionalInt rawInputMaskChannel; private final OptionalInt maskChannel; - public GroupedAggregator(GroupedAccumulator accumulator, Step step, Type intermediateType, Type finalType, List inputChannels, OptionalInt maskChannel) + public GroupedAggregator( + GroupedAccumulator accumulator, + Step step, + Type intermediateType, + Type finalType, + List rawInputChannels, + OptionalInt intermediateStateChannel, + OptionalInt rawInputMaskChannel, + OptionalInt maskChannel) { this.accumulator = requireNonNull(accumulator, "accumulator is null"); this.step = requireNonNull(step, "step is null"); this.intermediateType = requireNonNull(intermediateType, "intermediateType is null"); this.finalType = requireNonNull(finalType, "finalType is null"); - this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null")); + this.rawInputChannels = Ints.toArray(requireNonNull(rawInputChannels, "inputChannels is null")); + this.intermediateStateChannel = requireNonNull(intermediateStateChannel, "intermediateStateChannel is null"); + this.rawInputMaskChannel = requireNonNull(rawInputMaskChannel, "rawInputMaskChannel is null"); this.maskChannel = requireNonNull(maskChannel, "maskChannel is null"); - checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation"); + checkArgument(step.isInputRaw() || intermediateStateChannel.isPresent(), "expected intermediateStateChannel for intermediate aggregation but got %s ", intermediateStateChannel); } public long getEstimatedSize() @@ -67,11 +82,55 @@ public Type getType() public void processPage(GroupByIdBlock groupIds, Page page) { if (step.isInputRaw()) { - accumulator.addInput(groupIds, page.getColumns(inputChannels), getMaskBlock(page)); + accumulator.addInput(groupIds, page.getColumns(rawInputChannels), getMaskBlock(page)); + return; } - else { - accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0])); + + if (rawInputMaskChannel.isEmpty()) { + // process partially aggregated data + accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel.getAsInt())); + return; + } + Block rawInputMaskBlock = page.getBlock(rawInputMaskChannel.getAsInt()); + if (rawInputMaskBlock instanceof RunLengthEncodedBlock) { + if (rawInputMaskBlock.isNull(0)) { + // process partially aggregated data + accumulator.addIntermediate(groupIds, page.getBlock(intermediateStateChannel.getAsInt())); + } + else { + // process raw data + accumulator.addInput(groupIds, page.getColumns(rawInputChannels), getMaskBlock(page)); + } + return; } + + // rawInputMaskBlock has potentially mixed partially aggregated and raw data + Block maskBlock = getMaskBlock(page) + .map(mask -> andMasks(mask, rawInputMaskBlock)) + .orElse(rawInputMaskBlock); + accumulator.addInput(groupIds, page.getColumns(rawInputChannels), Optional.of(maskBlock)); + IntArrayList intermediatePositions = filterByNull(rawInputMaskBlock); + Block intermediateStateBlock = page.getBlock(intermediateStateChannel.getAsInt()); + + if (intermediatePositions.size() != rawInputMaskBlock.getPositionCount()) { + // some rows were eliminated by the filter + intermediateStateBlock = intermediateStateBlock.getPositions(intermediatePositions.elements(), 0, intermediatePositions.size()); + groupIds = new GroupByIdBlock( + groupIds.getGroupCount(), + groupIds.getPositions(intermediatePositions.elements(), 0, intermediatePositions.size())); + } + + accumulator.addIntermediate(groupIds, intermediateStateBlock); + } + + private Block andMasks(Block mask1, Block mask2) + { + int positionCount = mask1.getPositionCount(); + byte[] mask = new byte[positionCount]; + for (int i = 0; i < positionCount; i++) { + mask[i] = (byte) ((!mask1.isNull(i) && mask1.getByte(i, 0) == 1 && !mask2.isNull(i) && mask2.getByte(i, 0) == 1) ? 1 : 0); + } + return new ByteArrayBlock(positionCount, Optional.empty(), mask); } private Optional getMaskBlock(Page page) @@ -107,4 +166,18 @@ public Type getSpillType() { return intermediateType; } + + private static IntArrayList filterByNull(Block mask) + { + int positions = mask.getPositionCount(); + + IntArrayList ids = new IntArrayList(positions); + for (int i = 0; i < positions; ++i) { + if (mask.isNull(i)) { + ids.add(i); + } + } + + return ids; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index d4cd20cd0944..3722acb0e5e7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -29,6 +29,7 @@ import io.trino.operator.WorkProcessor.ProcessState; import io.trino.operator.aggregation.AggregatorFactory; import io.trino.operator.aggregation.GroupedAggregator; +import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.BlockBuilder; @@ -45,6 +46,7 @@ import java.util.Optional; import java.util.OptionalLong; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.operator.GroupByHash.createGroupByHash; import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; @@ -57,6 +59,7 @@ public class InMemoryHashAggregationBuilder private final boolean partial; private final OptionalLong maxPartialMemory; private final UpdateMemory updateMemory; + private final Optional partialAggregationOutputProcessor; private boolean full; @@ -66,6 +69,7 @@ public InMemoryHashAggregationBuilder( int expectedGroups, List groupByTypes, List groupByChannels, + Optional partialAggregationOutputProcessor, Optional hashChannel, OperatorContext operatorContext, Optional maxPartialMemory, @@ -78,6 +82,7 @@ public InMemoryHashAggregationBuilder( expectedGroups, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, hashChannel, operatorContext, maxPartialMemory, @@ -93,6 +98,7 @@ public InMemoryHashAggregationBuilder( int expectedGroups, List groupByTypes, List groupByChannels, + Optional partialAggregationOutputProcessor, Optional hashChannel, OperatorContext operatorContext, Optional maxPartialMemory, @@ -113,7 +119,10 @@ public InMemoryHashAggregationBuilder( this.partial = step.isOutputPartial(); this.maxPartialMemory = maxPartialMemory.map(dataSize -> OptionalLong.of(dataSize.toBytes())).orElseGet(OptionalLong::empty); this.updateMemory = requireNonNull(updateMemory, "updateMemory is null"); - + this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null"); + if (partial) { + checkArgument(partialAggregationOutputProcessor.isPresent(), "partialAggregationOutputProcessor must be present in the partial step"); + } // wrapper each function with an aggregator ImmutableList.Builder builder = ImmutableList.builder(); requireNonNull(aggregatorFactories, "aggregatorFactories is null"); @@ -290,7 +299,12 @@ private WorkProcessor buildResult(IntIterator groupIds) } } - return ProcessState.ofResult(pageBuilder.build()); + Page page = pageBuilder.build(); + if (partial) { + page = partialAggregationOutputProcessor.get().processAggregatedPage(page); + } + + return ProcessState.ofResult(page); }); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/MergingHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/MergingHashAggregationBuilder.java index ec53d872f881..f3810374a0ba 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/MergingHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/MergingHashAggregationBuilder.java @@ -22,6 +22,7 @@ import io.trino.operator.WorkProcessor.Transformation; import io.trino.operator.WorkProcessor.TransformationState; import io.trino.operator.aggregation.AggregatorFactory; +import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor; import io.trino.spi.Page; import io.trino.spi.type.Type; import io.trino.sql.gen.JoinCompiler; @@ -46,6 +47,7 @@ public class MergingHashAggregationBuilder private final WorkProcessor sortedPages; private InMemoryHashAggregationBuilder hashAggregationBuilder; private final List groupByTypes; + private final Optional partialAggregationOutputProcessor; private final LocalMemoryContext memoryContext; private final long memoryLimitForMerge; private final int overwriteIntermediateChannelOffset; @@ -57,6 +59,7 @@ public MergingHashAggregationBuilder( AggregationNode.Step step, int expectedGroups, List groupByTypes, + Optional partialAggregationOutputProcessor, Optional hashChannel, OperatorContext operatorContext, WorkProcessor sortedPages, @@ -79,6 +82,7 @@ public MergingHashAggregationBuilder( this.operatorContext = operatorContext; this.sortedPages = sortedPages; this.groupByTypes = groupByTypes; + this.partialAggregationOutputProcessor = partialAggregationOutputProcessor; this.memoryContext = aggregatedMemoryContext.newLocalMemoryContext(MergingHashAggregationBuilder.class.getSimpleName()); this.memoryLimitForMerge = memoryLimitForMerge; this.overwriteIntermediateChannelOffset = overwriteIntermediateChannelOffset; @@ -149,6 +153,7 @@ private void rebuildHashAggregationBuilder() expectedGroups, groupByTypes, groupByPartialChannels, + partialAggregationOutputProcessor, hashChannel, operatorContext, Optional.of(DataSize.succinctBytes(0)), diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java index da7ad0f8461a..3404aeffd4c2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -24,6 +24,7 @@ import io.trino.operator.Work; import io.trino.operator.WorkProcessor; import io.trino.operator.aggregation.AggregatorFactory; +import io.trino.operator.aggregation.partial.PartialAggregationOutputProcessor; import io.trino.spi.Page; import io.trino.spi.type.Type; import io.trino.spiller.Spiller; @@ -54,6 +55,7 @@ public class SpillableHashAggregationBuilder private final int expectedGroups; private final List groupByTypes; private final List groupByChannels; + private final Optional partialAggregationOutputProcessor; private final Optional hashChannel; private final OperatorContext operatorContext; private final LocalMemoryContext localUserMemoryContext; @@ -80,6 +82,7 @@ public SpillableHashAggregationBuilder( int expectedGroups, List groupByTypes, List groupByChannels, + Optional partialAggregationOutputProcessor, Optional hashChannel, OperatorContext operatorContext, DataSize memoryLimitForMerge, @@ -93,6 +96,7 @@ public SpillableHashAggregationBuilder( this.expectedGroups = expectedGroups; this.groupByTypes = groupByTypes; this.groupByChannels = groupByChannels; + this.partialAggregationOutputProcessor = partialAggregationOutputProcessor; this.hashChannel = hashChannel; this.operatorContext = operatorContext; this.localUserMemoryContext = operatorContext.localUserMemoryContext(); @@ -310,6 +314,7 @@ private WorkProcessor mergeSortedPages(WorkProcessor sortedPages, lo step, expectedGroups, groupByTypes, + partialAggregationOutputProcessor, hashChannel, operatorContext, sortedPages, @@ -336,6 +341,7 @@ private void rebuildHashAggregationBuilder() expectedGroups, groupByTypes, groupByChannels, + partialAggregationOutputProcessor, hashChannel, operatorContext, Optional.of(DataSize.succinctBytes(0)), diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/PartialAggregationOutputProcessor.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/PartialAggregationOutputProcessor.java new file mode 100644 index 000000000000..bdb3e79909ae --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/PartialAggregationOutputProcessor.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.partial; + +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import io.trino.operator.aggregation.AggregatorFactory; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.util.Objects.requireNonNull; + +/** + * Creates final output of the partial aggregation step based on either aggregated or raw input. + *

+ * Partial aggregation step output has additional channels used to send + * raw input data when the partial aggregation is disabled. + * Those channels need to be present in the output, albeit empty, even if the + * partial aggregation is enabled. + *

+ * The additional channels are: + * - mask channels used by any aggregation + * - every aggregation input channels that are not already in hash channels + * - a boolean rawInputMask channel that contains decision which input, aggregated or raw, should be used for a given position. + */ +public class PartialAggregationOutputProcessor +{ + private final List aggregatorIntermediateOutputTypes; + private final List aggregationRawInputTypes; + private final int[] hashChannels; + private final int[] aggregationRawInputChannels; + private final int[] maskChannels; + private final int finalBlockCount; + + public PartialAggregationOutputProcessor( + List groupByChannels, + Optional inputHashChannel, + List aggregatorFactories, + List aggregationRawInputTypes, + List aggregationRawInputChannels) + { + this.aggregationRawInputTypes = ImmutableList.copyOf(aggregationRawInputTypes); + + this.aggregatorIntermediateOutputTypes = requireNonNull(aggregatorFactories, "aggregatorFactories is null") + .stream() + .map(AggregatorFactory::getIntermediateType) + .collect(toImmutableList()); + this.aggregationRawInputChannels = Ints.toArray(aggregationRawInputChannels); + this.maskChannels = Ints.toArray(aggregatorFactories + .stream() + .map(AggregatorFactory::getMaskChannel) + .flatMapToInt(OptionalInt::stream) + .boxed() + .distinct() + .collect(toImmutableList())); + this.hashChannels = new int[groupByChannels.size() + (inputHashChannel.isPresent() ? 1 : 0)]; + for (int i = 0; i < groupByChannels.size(); i++) { + hashChannels[i] = groupByChannels.get(i); + } + inputHashChannel.ifPresent(channelIndex -> hashChannels[groupByChannels.size()] = channelIndex); + finalBlockCount = hashChannels.length + aggregatorIntermediateOutputTypes.size() + maskChannels.length + aggregationRawInputChannels.size() + 1; + } + + public Page processAggregatedPage(Page page) + { + Block[] finalPage = new Block[finalBlockCount]; + int blockOffset = 0; + for (int i = 0; i < page.getChannelCount(); i++, blockOffset++) { + finalPage[blockOffset] = page.getBlock(i); + } + int positionCount = page.getPositionCount(); + + // mask channels + for (int i = 0; i < maskChannels.length; i++, blockOffset++) { + finalPage[blockOffset] = RunLengthEncodedBlock.create(BOOLEAN, null, positionCount); + } + // aggregation raw inputs + for (int i = 0; i < aggregationRawInputTypes.size(); i++, blockOffset++) { + finalPage[blockOffset] = RunLengthEncodedBlock.create(aggregationRawInputTypes.get(i), null, positionCount); + } + + // use raw input mask channel + finalPage[blockOffset] = RunLengthEncodedBlock.create(BOOLEAN, null, positionCount); + return new Page(positionCount, finalPage); + } + + public Page processRawInputPage(Page page) + { + Block[] finalPage = new Block[finalBlockCount]; + int blockOffset = 0; + // raw input hash channels + for (int i = 0; i < hashChannels.length; i++, blockOffset++) { + finalPage[blockOffset] = page.getBlock(hashChannels[i]); + } + // aggregator state channels + for (int i = 0; i < aggregatorIntermediateOutputTypes.size(); i++, blockOffset++) { + finalPage[blockOffset] = RunLengthEncodedBlock.create(aggregatorIntermediateOutputTypes.get(i), null, page.getPositionCount()); + } + // mask channels + for (int i = 0; i < maskChannels.length; i++, blockOffset++) { + finalPage[blockOffset] = page.getBlock(maskChannels[i]); + } + // aggregation raw inputs + for (int i = 0; i < aggregationRawInputChannels.length; i++, blockOffset++) { + finalPage[blockOffset] = page.getBlock(aggregationRawInputChannels[i]); + } + // use raw input mask channel + finalPage[blockOffset] = RunLengthEncodedBlock.create(BOOLEAN, true, page.getPositionCount()); + return new Page(page.getPositionCount(), finalPage); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java index 3e9f00b9ad8d..f0cbe0a9d670 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/partial/SkipAggregationBuilder.java @@ -16,25 +16,15 @@ import com.google.common.util.concurrent.ListenableFuture; import io.trino.memory.context.LocalMemoryContext; import io.trino.operator.CompletedWork; -import io.trino.operator.GroupByIdBlock; import io.trino.operator.HashCollisionsCounter; import io.trino.operator.Work; import io.trino.operator.WorkProcessor; -import io.trino.operator.aggregation.AggregatorFactory; -import io.trino.operator.aggregation.GroupedAggregator; import io.trino.operator.aggregation.builder.HashAggregationBuilder; import io.trino.spi.Page; -import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.LongArrayBlock; import javax.annotation.Nullable; -import java.util.List; -import java.util.Optional; - import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** @@ -45,28 +35,17 @@ public class SkipAggregationBuilder implements HashAggregationBuilder { + private final PartialAggregationOutputProcessor partialAggregationOutputProcessor; private final LocalMemoryContext memoryContext; - private final List groupedAggregators; @Nullable private Page currentPage; - private final int[] hashChannels; public SkipAggregationBuilder( - List groupByChannels, - Optional inputHashChannel, - List aggregatorFactories, + PartialAggregationOutputProcessor partialAggregationOutputProcessor, LocalMemoryContext memoryContext) { + this.partialAggregationOutputProcessor = requireNonNull(partialAggregationOutputProcessor, "partialAggregationOutputProcessor is null"); this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); - this.groupedAggregators = requireNonNull(aggregatorFactories, "aggregatorFactories is null") - .stream() - .map(AggregatorFactory::createGroupedAggregator) - .collect(toImmutableList()); - this.hashChannels = new int[groupByChannels.size() + (inputHashChannel.isPresent() ? 1 : 0)]; - for (int i = 0; i < groupByChannels.size(); i++) { - hashChannels[i] = groupByChannels.get(i); - } - inputHashChannel.ifPresent(channelIndex -> hashChannels[groupByChannels.size()] = channelIndex); } @Override @@ -84,7 +63,7 @@ public WorkProcessor buildResult() return WorkProcessor.of(); } - Page result = buildOutputPage(currentPage); + Page result = partialAggregationOutputProcessor.processRawInputPage(currentPage); currentPage = null; return WorkProcessor.of(result); } @@ -125,66 +104,4 @@ public void finishMemoryRevoke() { throw new UnsupportedOperationException("finishMemoryRevoke not supported for SkipAggregationBuilder"); } - - private Page buildOutputPage(Page page) - { - populateInitialAccumulatorState(page); - - BlockBuilder[] outputBuilders = serializeAccumulatorState(page.getPositionCount()); - - return constructOutputPage(page, outputBuilders); - } - - private void populateInitialAccumulatorState(Page page) - { - GroupByIdBlock groupByIdBlock = getGroupByIdBlock(page.getPositionCount()); - for (GroupedAggregator groupedAggregator : groupedAggregators) { - groupedAggregator.processPage(groupByIdBlock, page); - } - } - - private GroupByIdBlock getGroupByIdBlock(int positionCount) - { - return new GroupByIdBlock( - positionCount, - new LongArrayBlock(positionCount, Optional.empty(), consecutive(positionCount))); - } - - private BlockBuilder[] serializeAccumulatorState(int positionCount) - { - BlockBuilder[] outputBuilders = new BlockBuilder[groupedAggregators.size()]; - for (int i = 0; i < outputBuilders.length; i++) { - outputBuilders[i] = groupedAggregators.get(i).getType().createBlockBuilder(null, positionCount); - } - - for (int position = 0; position < positionCount; position++) { - for (int i = 0; i < groupedAggregators.size(); i++) { - GroupedAggregator groupedAggregator = groupedAggregators.get(i); - BlockBuilder output = outputBuilders[i]; - groupedAggregator.evaluate(position, output); - } - } - return outputBuilders; - } - - private Page constructOutputPage(Page page, BlockBuilder[] outputBuilders) - { - Block[] outputBlocks = new Block[hashChannels.length + outputBuilders.length]; - for (int i = 0; i < hashChannels.length; i++) { - outputBlocks[i] = page.getBlock(hashChannels[i]); - } - for (int i = 0; i < outputBuilders.length; i++) { - outputBlocks[hashChannels.length + i] = outputBuilders[i].build(); - } - return new Page(page.getPositionCount(), outputBlocks); - } - - private static long[] consecutive(int positionCount) - { - long[] longs = new long[positionCount]; - for (int i = 0; i < positionCount; i++) { - longs[i] = i; - } - return longs; - } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index e5cbeeebdc3c..e305e793ff60 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -3184,6 +3184,7 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl PARTIAL, Optional.empty(), Optional.empty(), + Optional.empty(), source, false, false, @@ -3266,6 +3267,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl FINAL, Optional.empty(), Optional.empty(), + Optional.empty(), source, false, false, @@ -3590,7 +3592,8 @@ private List getSymbolTypes(List symbols, TypeProvider types) private AggregatorFactory buildAggregatorFactory( PhysicalOperation source, Aggregation aggregation, - Step step) + Step step, + OptionalInt rawInputMaskChannel) { List argumentChannels = new ArrayList<>(); for (Expression argument : aggregation.getArguments()) { @@ -3680,6 +3683,7 @@ private AggregatorFactory buildAggregatorFactory( intermediateType, finalType, argumentChannels, + rawInputMaskChannel, maskChannel, !aggregation.isDistinct() && aggregation.getOrderingScheme().isEmpty(), lambdaProviders); @@ -3767,7 +3771,7 @@ private AggregationOperatorFactory createAggregationOperatorFactory( for (Map.Entry entry : aggregations.entrySet()) { Symbol symbol = entry.getKey(); Aggregation aggregation = entry.getValue(); - aggregatorFactories.add(buildAggregatorFactory(source, aggregation, step)); + aggregatorFactories.add(buildAggregatorFactory(source, aggregation, step, OptionalInt.empty())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } @@ -3790,6 +3794,7 @@ private PhysicalOperation planGroupByAggregation( node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol(), + node.getRawInputMaskSymbol(), source, node.hasDefaultOutput(), spillEnabled, @@ -3811,6 +3816,7 @@ private OperatorFactory createHashAggregationOperatorFactory( Step step, Optional hashSymbol, Optional groupIdSymbol, + Optional rawInputMaskSymbol, PhysicalOperation source, boolean hasDefaultOutput, boolean spillEnabled, @@ -3818,17 +3824,21 @@ private OperatorFactory createHashAggregationOperatorFactory( DataSize unspillMemoryLimit, LocalExecutionPlanContext context, int startOutputChannel, - ImmutableMap.Builder outputMappings, + ImmutableMap.Builder outputMappingsBuilder, int expectedGroups, Optional maxPartialAggregationMemorySize) { List aggregationOutputSymbols = new ArrayList<>(); List aggregatorFactories = new ArrayList<>(); + OptionalInt rawInputMaskChannel = rawInputMaskSymbol + .filter(symbol -> !step.isInputRaw()) + .map(symbol -> OptionalInt.of(source.getLayout().get(symbol))) + .orElseGet(OptionalInt::empty); for (Map.Entry entry : aggregations.entrySet()) { Symbol symbol = entry.getKey(); Aggregation aggregation = entry.getValue(); - aggregatorFactories.add(buildAggregatorFactory(source, aggregation, step)); + aggregatorFactories.add(buildAggregatorFactory(source, aggregation, step, rawInputMaskChannel)); aggregationOutputSymbols.add(symbol); } @@ -3836,7 +3846,7 @@ private OperatorFactory createHashAggregationOperatorFactory( int channel = startOutputChannel; Optional groupIdChannel = Optional.empty(); for (Symbol symbol : groupBySymbols) { - outputMappings.put(symbol, channel); + outputMappingsBuilder.put(symbol, channel); if (groupIdSymbol.isPresent() && groupIdSymbol.get().equals(symbol)) { groupIdChannel = Optional.of(channel); } @@ -3845,12 +3855,37 @@ private OperatorFactory createHashAggregationOperatorFactory( // hashChannel follows the group by channels if (hashSymbol.isPresent()) { - outputMappings.put(hashSymbol.get(), channel++); + outputMappingsBuilder.put(hashSymbol.get(), channel++); } // aggregations go in following channels for (Symbol symbol : aggregationOutputSymbols) { - outputMappings.put(symbol, channel); + outputMappingsBuilder.put(symbol, channel); + channel++; + } + Map outputMappings = outputMappingsBuilder.buildOrThrow(); + List aggregationRawInputs = rawInputMaskSymbol.isPresent() ? + aggregations.values().stream() + .flatMap(Aggregation::getRawInputs) + .filter(input -> !outputMappings.containsKey(input)) + .distinct() + .collect(toImmutableList()) + : ImmutableList.of(); + if (step.isOutputPartial() && rawInputMaskSymbol.isPresent()) { + // add mask channels + for (Symbol symbol : aggregations.values().stream() + .map(Aggregation::getMask) + .flatMap(Optional::stream) + .collect(toImmutableSet())) { + outputMappingsBuilder.put(symbol, channel); + channel++; + } + // add inputs to the aggregations to be used by adaptive partial aggregation + for (Symbol symbol : aggregationRawInputs) { + outputMappingsBuilder.put(symbol, channel); + channel++; + } + outputMappingsBuilder.put(rawInputMaskSymbol.orElseThrow(), channel); channel++; } @@ -3870,12 +3905,18 @@ private OperatorFactory createHashAggregationOperatorFactory( joinCompiler); } else { + List aggregationRawInputChannels = getChannelsForSymbols(aggregationRawInputs, source.getLayout()); + List aggregationRawInputTypes = aggregationRawInputChannels.stream() + .map(entry -> source.getTypes().get(entry)) + .collect(toImmutableList()); Optional hashChannel = hashSymbol.map(channelGetter(source)); return new HashAggregationOperatorFactory( context.getNextOperatorId(), planNodeId, groupByTypes, groupByChannels, + aggregationRawInputTypes, + aggregationRawInputChannels, ImmutableList.copyOf(globalGroupingSets), step, hasDefaultOutput, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index a207904874e6..485ecd104fb9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -823,7 +823,8 @@ private PlanBuilder planAggregation(PlanBuilder subPlan, List> grou ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), - groupIdSymbol); + groupIdSymbol, + Optional.empty()); return new PlanBuilder( subPlan.getTranslations() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java index 6b787648dc01..b8a6ea7e963b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -118,7 +118,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont aggregation.getPreGroupedSymbols(), AggregationNode.Step.INTERMEDIATE, aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + aggregation.getGroupIdSymbol(), + aggregation.getRawInputMaskSymbol()); source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java index d0ff406f812c..d1b9b9f6919a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java @@ -75,6 +75,7 @@ public static AggregationNode restoreDistinctAggregation( ImmutableList.of(), distinct.getStep(), Optional.empty(), - Optional.empty()); + Optional.empty(), + distinct.getRawInputMaskSymbol()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index c1c8e518dff9..1d4f75640f50 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -47,6 +47,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.preferPartialAggregation; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -226,18 +227,16 @@ private PlanNode split(AggregationNode node, Context context) entry.getKey(), new AggregationNode.Aggregation( resolvedFunction, - ImmutableList.builder() - .add(intermediateSymbol.toSymbolReference()) - .addAll(originalAggregation.getArguments().stream() - .filter(LambdaExpression.class::isInstance) - .collect(toImmutableList())) - .build(), + buildFinalAggregationArguments(node.isGlobalAggregation(), originalAggregation, intermediateSymbol), false, Optional.empty(), Optional.empty(), - Optional.empty())); + node.isGlobalAggregation() ? Optional.empty() : originalAggregation.getMask())); } + Optional rawInputMaskSymbol = node.isGlobalAggregation() ? + Optional.empty() : + Optional.of(context.getSymbolAllocator().newSymbol("rawInputMaskSymbol", BOOLEAN)); PlanNode partial = new AggregationNode( context.getIdAllocator().getNextId(), node.getSource(), @@ -248,7 +247,8 @@ private PlanNode split(AggregationNode node, Context context) ImmutableList.of(), PARTIAL, node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getGroupIdSymbol(), + rawInputMaskSymbol); return new AggregationNode( node.getId(), @@ -260,6 +260,23 @@ private PlanNode split(AggregationNode node, Context context) ImmutableList.of(), FINAL, node.getHashSymbol(), - node.getGroupIdSymbol()); + node.getGroupIdSymbol(), + rawInputMaskSymbol); + } + + private List buildFinalAggregationArguments(boolean isGlobalAggregation, AggregationNode.Aggregation originalAggregation, Symbol intermediateSymbol) + { + if (isGlobalAggregation) { + return ImmutableList.builder() + .add(intermediateSymbol.toSymbolReference()) + .addAll(originalAggregation.getArguments().stream() + .filter(LambdaExpression.class::isInstance) + .collect(toImmutableList())) + .build(); + } + return ImmutableList.builder() + .add(intermediateSymbol.toSymbolReference()) + .addAll(originalAggregation.getArguments()) + .build(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctLimit.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctLimit.java index 4c7cd3c75d20..41c69516e253 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctLimit.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctLimit.java @@ -67,6 +67,7 @@ public Result apply(DistinctLimitNode node, Captures captures, Context context) ImmutableList.of(), SINGLE, node.getHashSymbol(), + Optional.empty(), Optional.empty())); } return Result.empty(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java index 41856e4865f0..5be3acb464ac 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithProjection.java @@ -150,7 +150,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co ImmutableList.of(), aggregation.getStep(), Optional.empty(), - Optional.empty()); + Optional.empty(), + aggregation.getRawInputMaskSymbol()); // restrict outputs and apply projection Set outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java index e6b3acce8bc9..9483d620d35e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java @@ -261,7 +261,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), - Optional.empty()); + Optional.empty(), + globalAggregation.getRawInputMaskSymbol()); // restrict outputs and apply projection Set outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java index 4cb37fdb4685..5a83046d1db5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java @@ -254,7 +254,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), - Optional.empty()); + Optional.empty(), + globalAggregation.getRawInputMaskSymbol()); // restrict outputs Optional project = restrictOutputs(context.getIdAllocator(), globalAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 35a2a9ce67e4..deb2175cb075 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -199,7 +199,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext hashSymbol; private final Optional groupIdSymbol; + // Mask symbol specifying if raw data or intermediate aggregation data is present in intermediate aggregation row. + // If empty, partial aggregation adaptation is disabled and intermediate state should be used. + // If the block is null at given position, the aggregated (intermediate) input should be used, + // otherwise raw input should be used (this happens when partial aggregation is disabled). + private final Optional rawInputMaskSymbol; private final List outputs; public static AggregationNode singleAggregation( @@ -62,7 +68,7 @@ public static AggregationNode singleAggregation( Map aggregations, GroupingSetDescriptor groupingSets) { - return new AggregationNode(id, source, aggregations, groupingSets, ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty()); + return new AggregationNode(id, source, aggregations, groupingSets, ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty(), Optional.empty()); } @JsonCreator @@ -74,20 +80,23 @@ public AggregationNode( @JsonProperty("preGroupedSymbols") List preGroupedSymbols, @JsonProperty("step") Step step, @JsonProperty("hashSymbol") Optional hashSymbol, - @JsonProperty("groupIdSymbol") Optional groupIdSymbol) + @JsonProperty("groupIdSymbol") Optional groupIdSymbol, + @JsonProperty("rawInputMaskSymbol") Optional rawInputMaskSymbol) { super(id); this.source = source; this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null")); - aggregations.values().forEach(aggregation -> aggregation.verifyArguments(step)); - requireNonNull(groupingSets, "groupingSets is null"); + boolean isGlobalAggregation = groupingSets.getGroupingKeys().isEmpty(); + aggregations.values().forEach(aggregation -> aggregation.verifyArguments(step, isGlobalAggregation)); + groupIdSymbol.ifPresent(symbol -> checkArgument(groupingSets.getGroupingKeys().contains(symbol), "Grouping columns does not contain groupId column")); this.groupingSets = groupingSets; this.groupIdSymbol = requireNonNull(groupIdSymbol); - + this.rawInputMaskSymbol = requireNonNull(rawInputMaskSymbol, "rawInputMaskSymbol is null"); + checkArgument(rawInputMaskSymbol.isPresent() || step == SINGLE || isGlobalAggregation); boolean noOrderBy = aggregations.values().stream() .map(Aggregation::getOrderingScheme) .noneMatch(Optional::isPresent); @@ -104,6 +113,24 @@ public AggregationNode( outputs.addAll(groupingSets.getGroupingKeys()); hashSymbol.ifPresent(outputs::add); outputs.addAll(aggregations.keySet()); + if (step.isOutputPartial() && !isGlobalAggregation) { + // add mask channels + aggregations.values().stream() + .map(Aggregation::getMask) + .flatMap(Optional::stream) + .distinct() + .forEach(outputs::add); + // add raw inputs to the aggregations to be used by adaptive partial aggregation. + // lambda arguments are ignored here since they are not real input just functions. + aggregations.values().stream() + .flatMap(Aggregation::getRawInputs) + .filter(symbol -> !groupingSets.getGroupingKeys().contains(symbol)) + .distinct() + .forEach(outputs::add); + + // add rawInputMask channel + outputs.add(rawInputMaskSymbol.orElseThrow()); + } this.outputs = outputs.build(); } @@ -200,6 +227,12 @@ public Optional getGroupIdSymbol() return groupIdSymbol; } + @JsonProperty("rawInputMaskSymbol") + public Optional getRawInputMaskSymbol() + { + return rawInputMaskSymbol; + } + public boolean hasOrderings() { return aggregations.values().stream() @@ -268,6 +301,11 @@ public boolean isStreamable() && groupingSets.getGlobalGroupingSets().isEmpty(); } + public boolean isGlobalAggregation() + { + return groupingSets.getGroupingKeys().isEmpty(); + } + public static GroupingSetDescriptor globalAggregation() { return singleGroupingSet(ImmutableList.of()); @@ -444,6 +482,13 @@ public Optional getMask() return mask; } + public Stream getRawInputs() + { + return getArguments().stream() + .filter(argument -> !(argument instanceof LambdaExpression)) + .map(Symbol::from); + } + @Override public boolean equals(Object o) { @@ -468,18 +513,22 @@ public int hashCode() return Objects.hash(resolvedFunction, arguments, distinct, filter, orderingScheme, mask); } - private void verifyArguments(Step step) + private void verifyArguments(Step step, boolean isGlobalAggregation) { int expectedArgumentCount; if (step == SINGLE || step == Step.PARTIAL) { expectedArgumentCount = resolvedFunction.getSignature().getArgumentTypes().size(); } - else { - // Intermediate and final steps get the intermediate value and the lambda functions + else if (isGlobalAggregation) { + // Global intermediate and final steps get the intermediate value and the lambda functions expectedArgumentCount = 1 + (int) resolvedFunction.getSignature().getArgumentTypes().stream() .filter(FunctionType.class::isInstance) .count(); } + else { + // Hash intermediate and final steps get the intermediate value and all the arguments + expectedArgumentCount = 1 + resolvedFunction.getSignature().getArgumentTypes().size(); + } checkArgument( expectedArgumentCount == arguments.size(), @@ -506,6 +555,7 @@ public static class Builder private Step step; private Optional hashSymbol; private Optional groupIdSymbol; + private Optional rawInputMaskSymbol; public Builder(AggregationNode node) { @@ -518,6 +568,7 @@ public Builder(AggregationNode node) this.step = node.getStep(); this.hashSymbol = node.getHashSymbol(); this.groupIdSymbol = node.getGroupIdSymbol(); + this.rawInputMaskSymbol = node.getRawInputMaskSymbol(); } public Builder setId(PlanNodeId id) @@ -568,6 +619,12 @@ public Builder setGroupIdSymbol(Optional groupIdSymbol) return this; } + public Builder setRawInputMaskSymbol(Optional rawInputMaskSymbol) + { + this.rawInputMaskSymbol = requireNonNull(rawInputMaskSymbol, "rawInputMaskSymbol is null"); + return this; + } + public AggregationNode build() { return new AggregationNode( @@ -578,7 +635,8 @@ public AggregationNode build() preGroupedSymbols, step, hashSymbol, - groupIdSymbol); + groupIdSymbol, + rawInputMaskSymbol); } } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java index ce33f622fd3e..49c802ca6c37 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java @@ -41,7 +41,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; +import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; @@ -91,7 +91,8 @@ static PlanFragment createAggregationFragment(String name, PlanFragment sourceFr ImmutableMap.of(), singleGroupingSet(ImmutableList.of()), ImmutableList.of(), - FINAL, + SINGLE, + Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkHashAndStreamingAggregationOperators.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkHashAndStreamingAggregationOperators.java index e25f28f029c9..2d7c96000662 100644 --- a/core/trino-main/src/test/java/io/trino/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -19,6 +19,7 @@ import io.trino.jmh.Benchmarks; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.HashAggregationOperator.HashAggregationOperatorFactory; +import io.trino.operator.aggregation.AggregatorFactory; import io.trino.operator.aggregation.TestingAggregationFunction; import io.trino.spi.Page; import io.trino.spi.block.Block; @@ -246,17 +247,19 @@ private OperatorFactory createHashAggregationOperatorFactory( { SpillerFactory spillerFactory = (types, localSpillContext, aggregatedMemoryContext) -> null; + ImmutableList aggregatorFactories = ImmutableList.of( + COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty()), + LONG_SUM.createAggregatorFactory(SINGLE, ImmutableList.of(sumChannel), OptionalInt.empty())); return new HashAggregationOperatorFactory( 0, new PlanNodeId("test"), hashTypes, hashChannels, + Optional.empty(), ImmutableList.of(), SINGLE, false, - ImmutableList.of( - COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty()), - LONG_SUM.createAggregatorFactory(SINGLE, ImmutableList.of(sumChannel), OptionalInt.empty())), + aggregatorFactories, hashChannel, Optional.empty(), 100_000, diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java index 3f05b637aaf2..e77bcfd2354a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java @@ -29,8 +29,12 @@ import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder; import io.trino.operator.aggregation.partial.PartialAggregationController; import io.trino.spi.Page; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.block.SingleRowBlockWriter; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spiller.Spiller; @@ -68,6 +72,8 @@ import static io.airlift.units.DataSize.succinctBytes; import static io.trino.RowPagesBuilder.rowPagesBuilder; import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.block.BlockAssertions.createBooleansBlock; +import static io.trino.block.BlockAssertions.createLongSequenceBlock; import static io.trino.block.BlockAssertions.createLongsBlock; import static io.trino.block.BlockAssertions.createRLEBlock; import static io.trino.operator.GroupByHashYieldAssertion.GroupByHashYieldResult; @@ -83,6 +89,7 @@ import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; import static io.trino.testing.MaterializedResult.resultBuilder; @@ -180,6 +187,7 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boole new PlanNodeId("test"), ImmutableList.of(VARCHAR), hashChannels, + Optional.empty(), ImmutableList.of(), SINGLE, false, @@ -234,6 +242,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna new PlanNodeId("test"), ImmutableList.of(VARCHAR, BIGINT), groupByChannels, + Optional.empty(), globalAggregationGroupIds, SINGLE, true, @@ -286,6 +295,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, + Optional.empty(), ImmutableList.of(), SINGLE, true, @@ -332,6 +342,8 @@ public void testMemoryLimit(boolean hashEnabled) ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), SINGLE, ImmutableList.of(COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty()), LONG_MIN.createAggregatorFactory(SINGLE, ImmutableList.of(3), OptionalInt.empty()), @@ -370,6 +382,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boo new PlanNodeId("test"), ImmutableList.of(VARCHAR), hashChannels, + Optional.empty(), ImmutableList.of(), SINGLE, false, @@ -399,6 +412,8 @@ public void testMemoryReservationYield(Type type) ImmutableList.of(type), ImmutableList.of(0), ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), SINGLE, ImmutableList.of(COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty())), Optional.of(1), @@ -452,6 +467,8 @@ public void testHashBuilderResizeLimit(boolean hashEnabled) ImmutableList.of(VARCHAR), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), SINGLE, ImmutableList.of(COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty())), rowPagesBuilder.getHashChannel(), @@ -486,6 +503,8 @@ public void testMultiSliceAggregationOutput(boolean hashEnabled) ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), SINGLE, ImmutableList.of(COUNT.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty()), LONG_AVERAGE.createAggregatorFactory(SINGLE, ImmutableList.of(1), OptionalInt.empty())), @@ -519,6 +538,8 @@ public void testMultiplePartialFlushes(boolean hashEnabled) ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), PARTIAL, ImmutableList.of(LONG_MIN.createAggregatorFactory(PARTIAL, ImmutableList.of(0), OptionalInt.empty())), rowPagesBuilder.getHashChannel(), @@ -532,10 +553,10 @@ public void testMultiplePartialFlushes(boolean hashEnabled) DriverContext driverContext = createDriverContext(1024); try (Operator operator = operatorFactory.createOperator(driverContext)) { - List expectedPages = rowPagesBuilder(BIGINT, BIGINT) - .addSequencePage(2000, 0, 0) + List expectedPages = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN) + .addBlocksPage(createLongSequenceBlock(0, 2000), createLongSequenceBlock(0, 2000), nullRle(BOOLEAN, 2000)) .build(); - MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT) + MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT, BOOLEAN) .pages(expectedPages) .build(); @@ -599,6 +620,7 @@ public void testMergeWithMemorySpill() new PlanNodeId("test"), ImmutableList.of(BIGINT), ImmutableList.of(0), + Optional.empty(), ImmutableList.of(), SINGLE, false, @@ -651,6 +673,7 @@ public void testSpillerFailure() new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, + Optional.empty(), ImmutableList.of(), SINGLE, false, @@ -690,6 +713,8 @@ public void testMemoryTracking() ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), SINGLE, ImmutableList.of(LONG_MIN.createAggregatorFactory(SINGLE, ImmutableList.of(0), OptionalInt.empty())), rowPagesBuilder.getHashChannel(), @@ -726,9 +751,11 @@ public void testAdaptivePartialAggregation() new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, + ImmutableList.of(BIGINT), + ImmutableList.of(1), ImmutableList.of(), PARTIAL, - ImmutableList.of(LONG_MIN.createAggregatorFactory(PARTIAL, ImmutableList.of(0), OptionalInt.empty())), + ImmutableList.of(LONG_MIN.createAggregatorFactory(PARTIAL, ImmutableList.of(1), OptionalInt.of(2))), Optional.empty(), Optional.empty(), 100, @@ -742,26 +769,40 @@ public void testAdaptivePartialAggregation() assertFalse(partialAggregationController.isPartialAggregationDisabled()); // First operator will trigger adaptive partial aggregation after the first page List operator1Input = rowPagesBuilder(false, hashChannels, BIGINT) - .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 8)) // first page will be hashed but the values are almost unique, so it will trigger adaptation - .addBlocksPage(createRLEBlock(1, 10)) // second page would be hashed to existing value 1. but if adaptive PA kicks in, the raw values will be passed on - .build(); - List operator1Expected = rowPagesBuilder(BIGINT, BIGINT) - .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8)) // the last position was aggregated - .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) // we are expecting second page with raw values + .addBlocksPage(// first page will be hashed but the values are almost unique, so it will trigger adaptation + createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 8), // hash channel + createRLEBlock(0, 10), // aggregation input + booleanRle(true, 10)) // mask channel + .addBlocksPage(// second page would be hashed to existing value 1. but if adaptive PA kicks in, the raw values will be passed on + createRLEBlock(1, 10), // hash channel + createRLEBlock(0, 10), // aggregation input + booleanRle(false, 10)) // mask channel .build(); + RowPagesBuilder operator1Expected = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN, BIGINT, BOOLEAN) + .addBlocksPage( + createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), // group ids, the last position was aggregated + createRLEBlock(0, 9), // intermediate state + nullRle(BOOLEAN, 10), // mask channel + nullRle(BIGINT, 9), // aggregation input + nullRle(BOOLEAN, 9)) // raw input mask + .addBlocksPage( + createRLEBlock(1, 10), // group ids, we are expecting second page with raw values not aggregated + nullRle(BIGINT, 10), // intermediate state + booleanRle(false, 10), // mask channel + createRLEBlock(0, 10), // raw aggregation input + booleanRle(true, 10)); // raw input mask assertOperatorEquals(operatorFactory, operator1Input, operator1Expected); // the first operator flush disables partial aggregation assertTrue(partialAggregationController.isPartialAggregationDisabled()); // second operator using the same factory, reuses PartialAggregationControl, so it will only produce raw pages (partial aggregation is disabled at this point) List operator2Input = rowPagesBuilder(false, hashChannels, BIGINT) - .addBlocksPage(createRLEBlock(1, 10)) - .addBlocksPage(createRLEBlock(2, 10)) - .build(); - List operator2Expected = rowPagesBuilder(BIGINT, BIGINT) - .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) - .addBlocksPage(createRLEBlock(2, 10), createRLEBlock(2, 10)) + .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(0, 10), booleanRle(false, 10)) + .addBlocksPage(createRLEBlock(2, 10), createRLEBlock(0, 10), booleanRle(false, 10)) .build(); + RowPagesBuilder operator2Expected = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN, BIGINT, BOOLEAN) + .addBlocksPage(createRLEBlock(1, 10), nullRle(BIGINT, 10), booleanRle(false, 10), createRLEBlock(0, 10), booleanRle(true, 10)) + .addBlocksPage(createRLEBlock(2, 10), nullRle(BIGINT, 10), booleanRle(false, 10), createRLEBlock(0, 10), booleanRle(true, 10)); assertOperatorEquals(operatorFactory, operator2Input, operator2Expected); } @@ -778,6 +819,8 @@ public void testAdaptivePartialAggregationTriggeredOnlyOnFlush() ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), PARTIAL, ImmutableList.of(LONG_MIN.createAggregatorFactory(PARTIAL, ImmutableList.of(0), OptionalInt.empty())), Optional.empty(), @@ -793,9 +836,9 @@ public void testAdaptivePartialAggregationTriggeredOnlyOnFlush() .addSequencePage(10, 0) // first page are unique values, so it would trigger adaptation, but it won't because flush is not called .addBlocksPage(createRLEBlock(1, 2)) // second page will be hashed to existing value 1 .build(); - // the total unique ows ratio for the first operator will be 10/12 so > 0.8 (adaptive partial aggregation uniqueRowsRatioThreshold) - List operator1Expected = rowPagesBuilder(BIGINT, BIGINT) - .addSequencePage(10, 0, 0) // we are expecting second page to be squashed with the first + // the total unique rows ratio for the first operator will be 10/12 so > 0.8 (adaptive partial aggregation uniqueRowsRatioThreshold) + List operator1Expected = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN) + .addBlocksPage(createLongSequenceBlock(0, 10), createLongSequenceBlock(0, 10), nullRle(BOOLEAN, 10)) // we are expecting second page to be squashed with the first .build(); assertOperatorEquals(operatorFactory, operator1Input, operator1Expected); @@ -807,23 +850,161 @@ public void testAdaptivePartialAggregationTriggeredOnlyOnFlush() .addBlocksPage(createRLEBlock(1, 10)) .addBlocksPage(createRLEBlock(2, 10)) .build(); - List operator2Expected = rowPagesBuilder(BIGINT, BIGINT) - .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) - .addBlocksPage(createRLEBlock(2, 10), createRLEBlock(2, 10)) + List operator2Expected = rowPagesBuilder(BIGINT, BIGINT, BOOLEAN) + .addBlocksPage(createRLEBlock(1, 10), nullRle(BIGINT, 10), booleanRle(true, 10)) + .addBlocksPage(createRLEBlock(2, 10), nullRle(BIGINT, 10), booleanRle(true, 10)) .build(); assertOperatorEquals(operatorFactory, operator2Input, operator2Expected); } + @Test + public void testFinalAggregation() + { + List hashChannels = Ints.asList(0); + + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(BIGINT), + ImmutableList.of(2), + ImmutableList.of(), + FINAL, + ImmutableList.of(LONG_SUM.createAggregatorFactory(FINAL, ImmutableList.of(1, 2), OptionalInt.empty(), OptionalInt.of(3))), + Optional.empty(), + Optional.empty(), + 100, + Optional.of(DataSize.of(16, MEGABYTE)), + joinCompiler, + blockTypeOperators, + Optional.empty()); + + List input = rowPagesBuilder(false, hashChannels, BIGINT, BIGINT, BIGINT, BOOLEAN) + .addBlocksPage(// aggregated page + createLongsBlock(0, 1, 2, 3), // hash channels + serializedLongSum(2, 2, 3, 3), // intermediate state + nullRle(BIGINT, 4), // aggregation input + nullRle(BOOLEAN, 4)) // raw input mask + .addBlocksPage(// raw input page + createLongsBlock(0, 0, 2, 2), // hash channels + nullRle(BIGINT, 4), // intermediate state + createLongsBlock(1, 2, -1, -2), // aggregation input + booleanRle(true, 4)) // raw input mask + .addBlocksPage(// mixed aggregated and raw input page + createLongsBlock(2, 3, 3, 3), // hash channels + serializedLongSum(1, 1, null, null), // intermediate state + createLongsBlock(null, null, 1L, 1L), // aggregation input + createBooleansBlock(null, null, true, true)) // raw input mask + .build(); + RowPagesBuilder expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage( + createLongsBlock(0, 1, 2, 3), // hash channels + createLongsBlock(2 + 1 + 2, 2, 3 - 1 + -2 + 1, 3 + 1 + 1 + 1)); // aggregation result + + assertOperatorEquals(operatorFactory, input, expected); + } + + @Test + public void testFinalAggregationWithMask() + { + List hashChannels = Ints.asList(0); + + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(BIGINT), + ImmutableList.of(2), + ImmutableList.of(), + FINAL, + ImmutableList.of(LONG_SUM.createAggregatorFactory(FINAL, ImmutableList.of(1, 3), OptionalInt.of(2), OptionalInt.of(4))), + Optional.empty(), + Optional.empty(), + 100, + Optional.of(DataSize.of(16, MEGABYTE)), + joinCompiler, + blockTypeOperators, + Optional.empty()); + + List input = rowPagesBuilder(false, hashChannels, BIGINT, BIGINT, BIGINT, BOOLEAN) + .addBlocksPage(// aggregated page + createLongsBlock(0, 1, 2, 3), // hash channels + serializedLongSum(2, 2, 3, 3), // intermediate state + nullRle(BOOLEAN, 4), // mask channel + nullRle(BIGINT, 4), // aggregation input + nullRle(BOOLEAN, 4)) // raw input mask + .addBlocksPage(// raw input page + createLongsBlock(0, 0, 2, 2), // hash channels + nullRle(BIGINT, 4), // intermediate state + booleanRle(true, 4), // mask channel + createLongsBlock(1, 2, -1, -2), // aggregation input + booleanRle(true, 4)) // raw input mask + .addBlocksPage(// mixed aggregated and raw input page + createLongsBlock(2, 3, 3, 3), // hash channels + serializedLongSum(1, null, null, null), // intermediate state + createBooleansBlock(null, false, true, true), // mask channel + createLongsBlock(null, 100L, 1L, 1L), // aggregation input + createBooleansBlock(null, true, true, true)) // raw input mask + .build(); + RowPagesBuilder expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage( + createLongsBlock(0, 1, 2, 3), // hash channels + createLongsBlock(2 + 1 + 2, 2, 3 - 1 + -2 + 1, 3 + 1 + 1)); // aggregation result + + assertOperatorEquals(operatorFactory, input, expected); + } + + private Block serializedLongSum(Integer... values) + { + RowBlockBuilder builder = new RowBlockBuilder(ImmutableList.of(BIGINT, BOOLEAN, BIGINT, BOOLEAN), null, values.length); + for (int i = 0; i < values.length; i++) { + if (values[i] == null) { + builder.appendNull(); + } + else { + SingleRowBlockWriter rowWriter = builder.beginBlockEntry(); + BIGINT.writeLong(rowWriter, 1); + BOOLEAN.writeBoolean(rowWriter, false); + BIGINT.writeLong(rowWriter, values[i]); + BOOLEAN.writeBoolean(rowWriter, false); + builder.closeEntry(); + } + } + return builder.build(); + } + + private void assertOperatorEquals(OperatorFactory operatorFactory, List input, RowPagesBuilder expectedPages) + { + assertOperatorEquals(operatorFactory, input, expectedPages.build(), expectedPages.getTypes()); + } + private void assertOperatorEquals(OperatorFactory operatorFactory, List input, List expectedPages) + { + assertOperatorEquals(operatorFactory, input, expectedPages, ImmutableList.of(BIGINT, BIGINT, BOOLEAN)); + } + + private void assertOperatorEquals(OperatorFactory operatorFactory, List input, List expectedPages, List expectedTypes) { DriverContext driverContext = createDriverContext(1024); - MaterializedResult expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT) + MaterializedResult expected = resultBuilder(driverContext.getSession(), expectedTypes) .pages(expectedPages) .build(); OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected, false, false); } + private static Block booleanRle(boolean value, int positionCount) + { + return RunLengthEncodedBlock.create(BOOLEAN, value, positionCount); + } + + private static Block nullRle(Type type, int positionCount) + { + return RunLengthEncodedBlock.create(type, null, positionCount); + } + private DriverContext createDriverContext() { return createDriverContext(Integer.MAX_VALUE); diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index eb9ea8e3fe91..c76436e55bed 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -97,15 +97,20 @@ public Type getFinalType() public AggregatorFactory createAggregatorFactory(Step step, List inputChannels, OptionalInt maskChannel) { - return createAggregatorFactory(step, inputChannels, maskChannel, factory); + return createAggregatorFactory(step, inputChannels, maskChannel, OptionalInt.empty(), factory); + } + + public AggregatorFactory createAggregatorFactory(Step step, List inputChannels, OptionalInt maskChannel, OptionalInt rawInputMaskChannel) + { + return createAggregatorFactory(step, inputChannels, maskChannel, rawInputMaskChannel, factory); } public AggregatorFactory createDistinctAggregatorFactory(Step step, List inputChannels, OptionalInt maskChannel) { - return createAggregatorFactory(step, inputChannels, maskChannel, distinctFactory); + return createAggregatorFactory(step, inputChannels, maskChannel, OptionalInt.empty(), distinctFactory); } - private AggregatorFactory createAggregatorFactory(Step step, List inputChannels, OptionalInt maskChannel, AccumulatorFactory distinctFactory) + private AggregatorFactory createAggregatorFactory(Step step, List inputChannels, OptionalInt maskChannel, OptionalInt rawInputMaskChannel, AccumulatorFactory distinctFactory) { return new AggregatorFactory( distinctFactory, @@ -113,6 +118,7 @@ private AggregatorFactory createAggregatorFactory(Step step, List input intermediateType, finalType, inputChannels, + rawInputMaskChannel, maskChannel, true, ImmutableList.of()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 0ba894957661..b21cd5bcdb20 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -260,7 +260,8 @@ public void testGroupByEmpty() ImmutableMap.of(), globalAggregation(), ImmutableList.of(), - AggregationNode.Step.FINAL, + AggregationNode.Step.SINGLE, + Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index b523aa84e1ab..b3f9043db830 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -221,7 +221,7 @@ public void testAggregation() assertDistributedPlan("SELECT orderstatus, sum(totalprice) FROM orders GROUP BY orderstatus", anyTree( aggregation( - ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum"))), + ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum", "totalprice"))), FINAL, exchange(LOCAL, GATHER, exchange(REMOTE, REPARTITION, @@ -234,7 +234,7 @@ public void testAggregation() assertDistributedPlan("SELECT orderstatus, sum(totalprice) FROM orders WHERE orderstatus='O' GROUP BY orderstatus", anyTree( aggregation( - ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum"))), + ImmutableMap.of("final_sum", functionCall("sum", ImmutableList.of("partial_sum", "totalprice"))), FINAL, exchange(LOCAL, GATHER, exchange(REMOTE, REPARTITION, @@ -1739,7 +1739,7 @@ public void testGroupingSetsWithDefaultValue() output( anyTree( aggregation( - ImmutableMap.of("final_count", functionCall("count", ImmutableList.of("partial_count"))), + ImmutableMap.of("final_count", functionCall("count", ImmutableList.of("partial_count", "CONSTANT"))), FINAL, exchange( LOCAL, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanNodePartitioning.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanNodePartitioning.java index 4a1f065e7778..995a6cd1ed98 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanNodePartitioning.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableScanNodePartitioning.java @@ -172,7 +172,7 @@ void assertTableScanPlannedWithPartitioning(Session session, String table, Conne String query = "SELECT count(column_b) FROM " + table + " GROUP BY column_a"; assertDistributedPlan(query, session, anyTree( - aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART"))), FINAL, + aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART", "B"))), FINAL, exchange(LOCAL, REPARTITION, project( aggregation(ImmutableMap.of("COUNT_PART", functionCall("count", ImmutableList.of("B"))), PARTIAL, @@ -187,7 +187,7 @@ void assertTableScanPlannedWithoutPartitioning(Session session, String table) String query = "SELECT count(column_b) FROM " + table + " GROUP BY column_a"; assertDistributedPlan("SELECT count(column_b) FROM " + table + " GROUP BY column_a", session, anyTree( - aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART"))), FINAL, + aggregation(ImmutableMap.of("COUNT", functionCall("count", ImmutableList.of("COUNT_PART", "B"))), FINAL, exchange(LOCAL, REPARTITION, exchange(REMOTE, REPARTITION, project( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 297160b89322..bd8e99901297 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -18,7 +18,6 @@ import io.trino.sql.planner.assertions.ExpectedValueProvider; import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.tree.FunctionCall; @@ -57,13 +56,12 @@ public void testBasic() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); @@ -108,13 +106,12 @@ public void testNoInputCount() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(*)"), ImmutableList.of()) .source( p.values(p.symbol("a")))))); @@ -157,15 +154,14 @@ public void testMultipleExchanges() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a"))))))); @@ -207,13 +203,12 @@ public void testSessionDisable() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); @@ -231,13 +226,12 @@ public void testNoLocalParallel() .setSystemProperty(TASK_CONCURRENCY, "1") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); @@ -271,13 +265,12 @@ public void testWithGroups() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.singleGroupingSet(p.symbol("c")) - .step(AggregationNode.Step.FINAL) - .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) + .finalAggregation() + .addFinalAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT), ImmutableList.of(p.symbol("a"))) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, - p.aggregation(ap -> ap.singleGroupingSet(p.symbol("b")) - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.singleGroupingSet(p.symbol("b")) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a")))))); @@ -295,15 +288,14 @@ public void testInterimProject() .setSystemProperty(TASK_CONCURRENCY, "4") .on(p -> p.aggregation(af -> { af.globalGrouping() - .step(AggregationNode.Step.FINAL) + .finalAggregation() .addAggregation(p.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT)) .source( p.gatheringExchange( ExchangeNode.Scope.REMOTE, p.project( Assignments.identity(p.symbol("b")), - p.aggregation(ap -> ap.globalGrouping() - .step(AggregationNode.Step.PARTIAL) + af.partialAggregation(ap -> ap.globalGrouping() .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) .source( p.values(p.symbol("a"))))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java index b0f615dcb2e1..30f3d0d915a6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPartialAggregationThroughJoin.java @@ -57,7 +57,8 @@ public void testPushesPartialAggregationThroughJoin() Optional.of(p.symbol("RIGHT_HASH")))) .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(LEFT_AGGR)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) - .step(PARTIAL))) + .step(PARTIAL) + .rawInputMaskSymbol())) .matches(project(ImmutableMap.of( "LEFT_GROUP_BY", PlanMatchPattern.expression("LEFT_GROUP_BY"), "RIGHT_GROUP_BY", PlanMatchPattern.expression("RIGHT_GROUP_BY"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 1bc3c0045814..edb9553e2621 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -124,6 +124,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -398,6 +399,7 @@ public class AggregationBuilder private Optional hashSymbol = Optional.empty(); private Optional groupIdSymbol = Optional.empty(); private Optional nodeId = Optional.empty(); + private Optional rawInputMaskSymbol = Optional.empty(); public AggregationBuilder source(PlanNode source) { @@ -429,6 +431,28 @@ private AggregationBuilder addAggregation(Symbol output, Expression expression, mask)); } + public AggregationBuilder addFinalAggregation(Symbol output, Expression expression, List inputTypes, List rawInputs) + { + return addFinalAggregation(output, expression, inputTypes, Optional.empty(), rawInputs); + } + + private AggregationBuilder addFinalAggregation(Symbol output, Expression expression, List inputTypes, Optional mask, List rawInputs) + { + checkArgument(expression instanceof FunctionCall); + FunctionCall aggregation = (FunctionCall) expression; + ResolvedFunction resolvedFunction = metadata.resolveFunction(session, aggregation.getName(), TypeSignatureProvider.fromTypes(inputTypes)); + return addAggregation(output, new Aggregation( + resolvedFunction, + ImmutableList.builder() + .addAll(aggregation.getArguments()) + .addAll(rawInputs.stream().map(Symbol::toSymbolReference).collect(toImmutableList())) + .build(), + aggregation.isDistinct(), + aggregation.getFilter().map(Symbol::from), + aggregation.getOrderBy().map(OrderingScheme::fromOrderBy), + mask)); + } + public AggregationBuilder addAggregation(Symbol output, Aggregation aggregation) { assignments.put(output, aggregation); @@ -485,6 +509,22 @@ public AggregationBuilder nodeId(PlanNodeId nodeId) return this; } + public AggregationBuilder rawInputMaskSymbol() + { + return rawInputMaskSymbol(symbol("rawInputMask", BOOLEAN)); + } + + public AggregationBuilder rawInputMaskSymbol(Symbol rawInputMaskSymbol) + { + return rawInputMaskSymbol(Optional.of(rawInputMaskSymbol)); + } + + private AggregationBuilder rawInputMaskSymbol(Optional rawInputMaskSymbol) + { + this.rawInputMaskSymbol = rawInputMaskSymbol; + return this; + } + protected AggregationNode build() { checkState(groupingSets != null, "No grouping sets defined; use globalGrouping/groupingKeys method"); @@ -496,7 +536,20 @@ protected AggregationNode build() preGroupedSymbols, step, hashSymbol, - groupIdSymbol); + groupIdSymbol, + rawInputMaskSymbol); + } + + public AggregationNode partialAggregation(Consumer aggregationBuilderConsumer) + { + return aggregation(aggregation -> aggregationBuilderConsumer.accept(aggregation + .step(Step.PARTIAL) + .rawInputMaskSymbol(rawInputMaskSymbol))); + } + + public AggregationBuilder finalAggregation() + { + return step(AggregationNode.Step.FINAL).rawInputMaskSymbol(symbol("rawInputMask", BOOLEAN)); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 68543fb05534..46750f085abb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -36,8 +36,6 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; -import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; -import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; @@ -72,10 +70,9 @@ public void setup() public void testGloballyDistributedFinalAggregationInTheSameStageAsPartialAggregation() { PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) - .source(builder.aggregation(ap -> ap - .step(PARTIAL) + .source(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)))); assertThatThrownBy(() -> validatePlan(root, false)) @@ -87,10 +84,9 @@ public void testGloballyDistributedFinalAggregationInTheSameStageAsPartialAggreg public void testSingleNodeFinalAggregationInTheSameStageAsPartialAggregation() { PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) - .source(builder.aggregation(ap -> ap - .step(PARTIAL) + .source(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)))); assertThatThrownBy(() -> validatePlan(root, true)) @@ -102,10 +98,9 @@ public void testSingleNodeFinalAggregationInTheSameStageAsPartialAggregation() public void testSingleThreadFinalAggregationInTheSameStageAsPartialAggregation() { PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) - .source(builder.aggregation(ap -> ap - .step(PARTIAL) + .source(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(builder.values())))); validatePlan(root, true); @@ -116,15 +111,14 @@ public void testGloballyDistributedFinalAggregationSeparatedFromPartialAggregati { Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(builder.exchange(e -> e .type(REPARTITION) .scope(REMOTE) .fixedHashDistributionPartitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) .addInputsSet(symbol) - .addSource(builder.aggregation(ap -> ap - .step(PARTIAL) + .addSource(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)))))); validatePlan(root, false); @@ -135,15 +129,14 @@ public void testSingleNodeFinalAggregationSeparatedFromPartialAggregationByLocal { Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(builder.exchange(e -> e .type(REPARTITION) .scope(LOCAL) .fixedHashDistributionPartitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) .addInputsSet(symbol) - .addSource(builder.aggregation(ap -> ap - .step(PARTIAL) + .addSource(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)))))); validatePlan(root, true); @@ -154,7 +147,7 @@ public void testWithPartialAggregationBelowJoin() { Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(builder.join( INNER, @@ -163,8 +156,7 @@ public void testWithPartialAggregationBelowJoin() .scope(LOCAL) .fixedHashDistributionPartitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol)) .addInputsSet(symbol) - .addSource(builder.aggregation(ap -> ap - .step(PARTIAL) + .addSource(af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)))), builder.values()))); @@ -176,12 +168,11 @@ public void testWithPartialAggregationBelowJoinWithoutSeparatingExchange() { Symbol symbol = new Symbol("symbol"); PlanNode root = builder.aggregation( - af -> af.step(FINAL) + af -> af.finalAggregation() .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(builder.join( INNER, - builder.aggregation(ap -> ap - .step(PARTIAL) + af.partialAggregation(ap -> ap .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0))) .source(tableScanNode)), builder.values()))); diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java index 0e00c363ab30..f4539b8fd6c1 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java @@ -51,6 +51,7 @@ public AggregatorFactory bind(List inputChannels) finalType, inputChannels, OptionalInt.empty(), + OptionalInt.empty(), true, ImmutableList.of()); } diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HandTpchQuery1.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HandTpchQuery1.java index 636988d87e61..1cc45db64d83 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HandTpchQuery1.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HandTpchQuery1.java @@ -107,6 +107,8 @@ protected List createOperatorFactories() getColumnTypes("lineitem", "returnflag", "linestatus"), Ints.asList(0, 1), ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), Step.SINGLE, ImmutableList.of( doubleSum.bind(ImmutableList.of(2)), diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashAggregationBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashAggregationBenchmark.java index 8eec3fd5bd98..6fde7ffb7c4d 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashAggregationBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/HashAggregationBenchmark.java @@ -54,6 +54,8 @@ protected List createOperatorFactories() ImmutableList.of(tableTypes.get(0)), Ints.asList(0), ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of(), Step.SINGLE, ImmutableList.of(doubleSum.bind(ImmutableList.of(1))), Optional.empty(), diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestAdaptivePartialAggregation.java b/testing/trino-tests/src/test/java/io/trino/tests/TestAdaptivePartialAggregation.java index 3b9fa208c216..91077940f984 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestAdaptivePartialAggregation.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestAdaptivePartialAggregation.java @@ -17,6 +17,7 @@ import io.trino.testing.AbstractTestAggregations; import io.trino.testing.QueryRunner; import io.trino.tests.tpch.TpchQueryRunnerBuilder; +import org.testng.annotations.Test; public class TestAdaptivePartialAggregation extends AbstractTestAggregations @@ -31,4 +32,16 @@ protected QueryRunner createQueryRunner() "task.max-partial-aggregation-memory", "0B")) .build(); } + + @Test + public void testAggregationWithLambda() + { + // case with partial aggregation disabled adaptively. + // orderkey + 1 is needed to avoid streaming aggregation. + assertQuery( + "SELECT orderkey + 1, reduce_agg(orderkey, 1, (a, b) -> a * b, (a, b) -> a * b) " + + "FROM orders " + + "GROUP BY orderkey + 1", + "SELECT orderkey + 1, orderkey from orders"); + } }