Common expression elimination should also re-find the correct expression, during re-write. #9870

Closed
wiedld opened this issue Mar 29, 2024 · 1 comment · Fixed by #9871
Labels
bug Something isn't working regression Something that used to work no longer does

Comments

@wiedld
Contributor

wiedld commented Mar 29, 2024

Describe the bug

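The common-expression-elimination optimizer rule (common_sub_expression_eliminate) fails when two projections each apply an aggregate function to the same short-circuiting expression, i.e. two aggr_fn(<short_circuit>) projections. Our reproducer fails on the final statement below: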

statement ok
CREATE TABLE t1(
    time TIMESTAMP,
    load1 DOUBLE,
    load2 DOUBLE,
    host VARCHAR
) AS VALUES
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
;

# cannot have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
statement error DataFusion error: Optimizer rule 'common_sub_expression_eliminate' failed
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;

We isolated this error to how the IdArray is generated and used. A first visitor builds the IdArray by inserting at an index as it traverses the expression tree. A second visitor then reads the IdArray, also by index, but it increments that index differently as it traverses.

As a result, the second visitor looks up the wrong expr and then inserts the wrong expression. Small changes fixed our bug but broke other statements. We suspect that index-based lookup is inherently fragile to slight differences in traversal between the two visitors, so we have a proposed alternative that will be up in a PR shortly.
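
To make the failure mode concrete, here is a minimal toy model in Python. It is not DataFusion's code: the Expr class, both visitor functions, and the rule "skip the children of a short-circuiting node" are assumptions made purely for illustration. It only shows how two traversals that advance a positional index differently end up pairing nodes with the wrong entries.

```python
from dataclasses import dataclass, field


@dataclass
class Expr:
    """Hypothetical expression node; not a DataFusion type."""
    name: str
    children: list = field(default_factory=list)
    short_circuits: bool = False


def build_id_array(root):
    """First visitor: store each visited node's name at its visit index.

    Assumption for this sketch: children of short-circuiting nodes are
    skipped, so they never get a slot.
    """
    id_array = []

    def visit(node):
        id_array.append(node.name)
        if node.short_circuits:
            return
        for child in node.children:
            visit(child)

    visit(root)
    return id_array


def read_id_array(root, id_array):
    """Second visitor: pair each visited node with id_array[index].

    Unlike the first visitor, it descends into every node, so its index
    advances on nodes that were never given a slot.
    """
    pairs = []
    index = 0

    def visit(node):
        nonlocal index
        looked_up = id_array[index] if index < len(id_array) else None
        pairs.append((node.name, looked_up))
        index += 1
        for child in node.children:
            visit(child)

    visit(root)
    return pairs


def make_case():
    # The same short-circuiting expression, used under both aggregates.
    return Expr("case", [Expr("cond"), Expr("then")], short_circuits=True)


root = Expr("project", [
    Expr("sum_c1", [make_case()]),
    Expr("sum_c2", [make_case()]),
])

id_array = build_id_array(root)
pairs = read_id_array(root, id_array)
mismatches = [(node, found) for node, found in pairs if node != found]
print(mismatches[0])
```

In this toy model the two visitors agree until the first short-circuiting node. From there on every lookup is off, starting with the node cond being paired with the entry recorded for sum_c2.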

To Reproduce

Full reproducer:

statement ok
CREATE TABLE t1(
    time TIMESTAMP,
    load1 DOUBLE,
    load2 DOUBLE,
    host VARCHAR
) AS VALUES
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
;

# struct scalar function with columns
query ?
select struct(time,load1,load2,host) from t1;
----
{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1}
{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2}
{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3}
{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: }

# can have an aggregate function with an inner coalesce
query TR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1
host2 2.2
host3 3.3

# can have an aggregate function with an inner CASE WHEN
query TR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 101
host2 202
host3 303

# can have 2 projections with aggr(short_circuited), with different short-circuited expr
query TRR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
statement error DataFusion error: Optimizer rule 'common_sub_expression_eliminate' failed
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;

# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
statement error DataFusion error: Optimizer rule 'common_sub_expression_eliminate' failed
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;

Expected behavior

All of the provided test cases should succeed. During the common-expression-elimination rewrite, the rewriter should not look up the wrong expr and then insert the wrong expression.
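
One way to get this property, sketched here in Python as an assumption and not as the approach taken in the eventual fix, is to have the rewriter derive a key from the expression it is currently visiting instead of from a position counter. The tuple-based expression shape and the identifier, count_subexprs and rewrite helpers are all hypothetical names for this toy model.

```python
from collections import Counter

# Toy expression shape (an assumption): (name, (child, child, ...))


def identifier(expr):
    """Structural key: equal expressions always produce equal keys."""
    name, children = expr
    return f"{name}({','.join(identifier(c) for c in children)})"


def count_subexprs(expr, counts):
    """First pass: count how often each subexpression occurs."""
    counts[identifier(expr)] += 1
    for child in expr[1]:
        count_subexprs(child, counts)


def rewrite(expr, counts):
    """Second pass: replace repeated non-leaf subexpressions by a reference.

    The key is recomputed from the node itself, so the lookup cannot drift
    even if this pass walks the tree differently from the first one.
    """
    name, children = expr
    key = identifier(expr)
    if children and counts[key] > 1:
        return (f"#{key}", ())
    return (name, tuple(rewrite(c, counts) for c in children))


case = ("case", (("cond", ()), ("struct", ())))
sum_c1 = ("sum", (("get_c1", (case,)),))
sum_c2 = ("sum", (("get_c2", (case,)),))

counts = Counter()
for projection in (sum_c1, sum_c2):
    count_subexprs(projection, counts)

rewritten_c1 = rewrite(sum_c1, counts)
rewritten_c2 = rewrite(sum_c2, counts)
print(rewritten_c1)
print(rewritten_c2)
```

Both projections end up referring to the same shared case expression. Recomputing the key at every node costs more than an index lookup, so a real implementation would likely cache or hash identifiers. The sketch ignores that trade-off.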

Additional context

No response

@alamb
Contributor

alamb commented Mar 30, 2024

I ran this reproducer on a 36.0.0 build and it works, so I think this is a regression:

(venv) andrewlamb@Andrews-MacBook-Pro:~/Software/arrow-datafusion$ ~/Software/DataFusionArchive/datafusion-cli-36.0.0
DataFusion CLI v36.0.0
❯ CREATE TABLE t1(
    time TIMESTAMP,
    load1 DOUBLE,
    load2 DOUBLE,
    host VARCHAR
) AS VALUES
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
;
0 rows in set. Query took 0.028 seconds.

❯ select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+-------+-----------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| host  | SUM(CASE WHEN (t2.struct(t1.time,t1.load1,t1.load2,t1.host))[c3] IS NOT NULL THEN t2.struct(t1.time,t1.load1,t1.load2,t1.host) END[c1]) | SUM(CASE WHEN (t2.struct(t1.time,t1.load1,t1.load2,t1.host))[c3] IS NOT NULL THEN t2.struct(t1.time,t1.load1,t1.load2,t1.host) END[c2]) |
+-------+-----------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
| host1 | 1.1                                                                                                                                     | 101.0                                                                                                                                   |
| host2 | 2.2                                                                                                                                     | 202.0                                                                                                                                   |
| host3 | 3.3                                                                                                                                     | 303.0                                                                                                                                   |
+-------+-----------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
3 rows in set. Query took 0.031 seconds.

I believe this was also introduced as part of the TreeNode refactor in #8891, similar to #9678.
