From 3e1850db1b7effea16b9130bd329a8240028f8a2 Mon Sep 17 00:00:00 2001 From: Maron Montano Date: Sun, 8 Sep 2024 11:51:06 +0800 Subject: [PATCH] Minor: Add getter for logical optimizer rules (#12379) * feat: new getter method for optimizer rules in parity with its physical counterpart * style: use rust_lint * chore: make struct local to the test scope --- .../core/src/execution/session_state.rs | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 675ac798bf4e..5e8d22b6ccbc 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -801,6 +801,11 @@ impl SessionState { &mut self.config } + /// Return the logical optimizers + pub fn optimizers(&self) -> &[Arc] { + &self.optimizer.rules + } + /// Return the physical optimizers pub fn physical_optimizers(&self) -> &[Arc] { &self.physical_optimizers.rules @@ -1826,6 +1831,8 @@ mod tests { use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; + use datafusion_optimizer::optimizer::OptimizerRule; + use datafusion_optimizer::Optimizer; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -1922,4 +1929,31 @@ mod tests { assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); Ok(()) } + + #[test] + fn test_session_state_with_optimizer_rules() { + struct DummyRule {} + + impl OptimizerRule for DummyRule { + fn name(&self) -> &str { + "dummy_rule" + } + } + // test building sessions with fresh set of rules + let state = SessionStateBuilder::new() + .with_optimizer_rules(vec![Arc::new(DummyRule {})]) + .build(); + + assert_eq!(state.optimizers().len(), 1); + + // test adding rules to default recommendations + let state = SessionStateBuilder::new() + .with_optimizer_rule(Arc::new(DummyRule {})) + .build(); + + assert_eq!( + state.optimizers().len(), + Optimizer::default().rules.len() + 1 + ); + } }