Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix join and unnest planning to ensure that duplicate join prefixes are not used #13943

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ private static RowSignature getCorrelateRowSignature(
RowSignature.builder().add(
BASE_UNNEST_OUTPUT_COLUMN,
Calcites.getColumnTypeForRelDataType(unnestedType)
).build()
).build(),
DruidJoinQueryRel.findExistingJoinPrefixes(leftQuery.getDataSource())
).rhs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
import org.apache.druid.sql.calcite.table.RowSignatures;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -160,7 +162,12 @@ public DruidQuery toDruidQuery(final boolean finalizeAggregations)
rightDataSource = rightQuery.getDataSource();
}

final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(leftSignature, rightSignature);

final Pair<String, RowSignature> prefixSignaturePair = computeJoinRowSignature(
leftSignature,
rightSignature,
findExistingJoinPrefixes(leftDataSource, rightDataSource)
);

VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create(
prefixSignaturePair.rhs,
Expand Down Expand Up @@ -380,13 +387,29 @@ public static boolean computeRightRequiresSubquery(final PlannerContext plannerC
&& DruidRels.druidTableIfLeafRel(right).filter(table -> table.getDataSource().isGlobal()).isPresent());
}

static Set<String> findExistingJoinPrefixes(DataSource... dataSources)
{
final ArrayList<DataSource> copy = new ArrayList<>(Arrays.asList(dataSources));

Set<String> prefixes = new HashSet<>();
while (!copy.isEmpty()) {
DataSource current = copy.remove(0);
copy.addAll(current.getChildren());
if (current instanceof JoinDataSource) {
JoinDataSource joiner = (JoinDataSource) current;
prefixes.add(joiner.getRightPrefix());
}
}
return prefixes;
}
/**
* Returns a Pair of "rightPrefix" (for JoinDataSource) and the signature of rows that will result from
* applying that prefix.
*/
static Pair<String, RowSignature> computeJoinRowSignature(
final RowSignature leftSignature,
final RowSignature rightSignature
final RowSignature rightSignature,
final Set<String> prefixes
)
{
final RowSignature.Builder signatureBuilder = RowSignature.builder();
Expand All @@ -395,8 +418,17 @@ static Pair<String, RowSignature> computeJoinRowSignature(
signatureBuilder.add(column, leftSignature.getColumnType(column).orElse(null));
}

// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
final String rightPrefix = Calcites.findUnusedPrefixForDigits("j", leftSignature.getColumnNames()) + "0.";
StringBuilder base = new StringBuilder("j");
// the prefixes collection contains all known join prefixes, which might be in use for nested queries but not
// present in the top level row signatures
// loop until we are sure we got a new prefix
String maybePrefix;
do {
// Need to include the "0" since findUnusedPrefixForDigits only guarantees safety for digit-initiated suffixes
maybePrefix = Calcites.findUnusedPrefixForDigits(base.toString(), leftSignature.getColumnNames()) + "0.";
base.insert(0, "_");
} while (prefixes.contains(maybePrefix));
final String rightPrefix = maybePrefix;

for (final String column : rightSignature.getColumnNames()) {
signatureBuilder.add(rightPrefix + column, rightSignature.getColumnType(column).orElse(null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.filter.NotDimFilter;
import org.apache.druid.query.filter.OrDimFilter;
Expand Down Expand Up @@ -95,8 +96,10 @@
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -4766,8 +4769,8 @@ public void testVirtualColumnOnMVFilterMultiJoinExpression(Map<String, Object> q
.context(queryContext)
.build()
),
"j0.",
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("j0.v0")),
"_j0.",
equalsCondition(makeColumnExpression("v0"), makeColumnExpression("_j0.v0")),

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.
JoinType.INNER
)
)
Expand All @@ -4778,7 +4781,7 @@ public void testVirtualColumnOnMVFilterMultiJoinExpression(Map<String, Object> q
ImmutableSet.of("a"),
true
))
.columns("dim3", "j0.dim3")
.columns("_j0.dim3", "dim3")
.context(queryContext)
.build()
),
Expand Down Expand Up @@ -5084,4 +5087,181 @@ public void testPlanWithInFilterMoreThanInSubQueryThreshold()
null
);
}

@Test
@Parameters(source = QueryContextForJoinProvider.class)
public void testRegressionFilteredAggregatorsSubqueryJoins(Map<String, Object> queryContext)
{
cannotVectorize();
testQuery(
"select\n" +
"count(*) filter (where trim(both from dim1) in (select dim2 from foo)),\n" +
"min(m1) filter (where 'A' not in (select m2 from foo))\n" +
"from foo as t0\n" +
"where __time in (select __time from foo)",
queryContext,
useDefault ?
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
join(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"j0.",
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.
JoinType.INNER
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimensions(
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"_j0.",
"(trim(\"dim1\",' ') == \"_j0.d0\")",
JoinType.LEFT
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimFilter(selector("m2", "A", null))
.setDimensions(
new DefaultDimensionSpec("v0", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"__j0.",
"1",
JoinType.LEFT
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.aggregators(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
and(
not(selector("_j0.d1", null, null)),
not(selector("dim1", null, null))
),
"a0"
),
new FilteredAggregatorFactory(
new FloatMinAggregatorFactory("a1", "m1"),
selector("__j0.d0", null, null),
"a1"
)
)
.context(queryContext)
.build()
) :
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(
join(
join(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"j0.",
equalsCondition(makeColumnExpression("__time"), makeColumnExpression("j0.d0")),

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.

Check notice

Code scanning / CodeQL

Deprecated method or constructor invocation

Invoking [CalciteTestBase.makeColumnExpression](1) should be avoided because it has been deprecated.
JoinType.INNER
),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
.setDimensions(
new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING),
new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
)
.setGranularity(Granularities.ALL)
.setLimitSpec(NoopLimitSpec.instance())
.build()
),
"_j0.",
"(trim(\"dim1\",' ') == \"_j0.d0\")",
JoinType.LEFT
),
new QueryDataSource(
new TopNQueryBuilder().dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.filters(new InDimFilter("m2", new HashSet<>(Arrays.asList(null, "A"))))
.virtualColumns(expressionVirtualColumn("v0", "notnull(\"m2\")", ColumnType.LONG))
.dimension(new DefaultDimensionSpec("v0", "d0", ColumnType.LONG))
.metric(new InvertedTopNMetricSpec(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC)))
.aggregators(new CountAggregatorFactory("a0"))
.threshold(1)
.build()
),
"__j0.",
"1",
JoinType.LEFT
)
)
.intervals(querySegmentSpec(Filtration.eternity()))
.aggregators(
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
and(
not(selector("_j0.d1", null, null)),
not(selector("dim1", null, null))
),
"a0"
),
new FilteredAggregatorFactory(
new FloatMinAggregatorFactory("a1", "m1"),
or(
selector("__j0.a0", null, null),
not(
or(
not(expressionFilter("\"__j0.d0\"")),
not(selector("__j0.d0", null, null))
)
)
),
"a1"
)
)
.context(queryContext)
.build()
),
ImmutableList.of(
new Object[]{useDefault ? 1L : 2L, 1.0f}
)
);
}
}