diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index c11eb3280c20f..b83f659756105 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -29,6 +29,7 @@ use datafusion::datasource::listing::{ use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; +use datafusion::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use dirs::home_dir; @@ -162,6 +163,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); + let mut builder = SessionStateBuilder::from(state.clone()); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); @@ -178,13 +180,18 @@ impl SchemaProvider for DynamicFileSchemaProvider { // to any command options so the only choice is to use an empty collection match scheme { "s3" | "oss" | "cos" => { - state = state.add_table_options_extension(AwsOptions::default()); + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(AwsOptions::default()) + } } "gs" | "gcs" => { - state = state.add_table_options_extension(GcpOptions::default()) + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(GcpOptions::default()) + } } _ => {} }; + state = builder.build(); let store = get_object_store( &state, table_url.scheme(), diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index fe936418bce4a..bdb702375c945 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -22,6 +22,7 @@ use arrow::{ datatypes::UInt64Type, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ datasource::{ file_format::{ @@ -32,9 +33,9 @@ use datafusion::{ MemTable, }, error::Result, - execution::{context::SessionState, runtime_env::RuntimeEnv}, + execution::context::SessionState, physical_plan::ExecutionPlan, - prelude::{SessionConfig, SessionContext}, + prelude::SessionContext, }; use datafusion_common::{GetExt, Statistics}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; @@ -176,9 +177,7 @@ impl GetExt for TSVFileFactory { #[tokio::main] async fn main() -> Result<()> { // Create a new context with the default configuration - let config = SessionConfig::new(); - let runtime = RuntimeEnv::default(); - let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // Register the custom file format let file_format = Arc::new(TSVFileFactory::new()); diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 92cb11e2b47a4..baeaf51fb56d1 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -632,6 +632,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; + use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -814,7 +815,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new()).unwrap()); let mut cfg = SessionConfig::new(); cfg.options_mut().catalog.has_header = true; - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4b9e3e843341a..640a9b14a65f1 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -73,6 +73,7 @@ use object_store::ObjectStore; use parking_lot::RwLock; use url::Url; +use crate::execution::session_state::SessionStateBuilder; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; @@ -294,7 +295,11 @@ impl SessionContext { /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); Self::new_with_state(state) } @@ -315,7 +320,7 @@ impl SessionContext { } /// Creates a new `SessionContext` using the provided [`SessionState`] - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] + #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_state")] pub fn with_state(state: SessionState) -> Self { Self::new_with_state(state) } @@ -1574,6 +1579,7 @@ mod tests { use datafusion_common_runtime::SpawnedTask; use crate::catalog::schema::SchemaProvider; + use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; use tempfile::TempDir; @@ -1707,7 +1713,11 @@ mod tests { .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; @@ -1733,9 +1743,12 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = - SessionState::new_with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); + let session_state = SessionStateBuilder::new() + .with_config(SessionConfig::new()) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(MyQueryPlanner {})) + .build(); let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index dbfba9ea93521..75eef43454873 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -77,6 +77,8 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use itertools::Itertools; +use log::{debug, info}; use sqlparser::ast::Expr as SQLExpr; use sqlparser::dialect::dialect_from_str; use std::collections::hash_map::Entry; @@ -89,9 +91,29 @@ use uuid::Uuid; /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. /// +/// Use the [`SessionStateBuilder`] to build a SessionState object. +/// +/// ``` +/// use datafusion::prelude::*; +/// # use datafusion::{error::Result, assert_batches_eq}; +/// # use datafusion::execution::session_state::SessionStateBuilder; +/// # use datafusion_execution::runtime_env::RuntimeEnv; +/// # use std::sync::Arc; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let state = SessionStateBuilder::new() +/// .with_config(SessionConfig::new()) +/// .with_runtime_env(Arc::new(RuntimeEnv::default())) +/// .with_default_features() +/// .build(); +/// Ok(()) +/// # } +/// ``` +/// /// Note that there is no `Default` or `new()` for SessionState, /// to avoid accidentally running queries or other operations without passing through -/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionContext`]. +/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionStateBuilder`] and +/// [`SessionContext`]. /// /// [`SessionContext`]: crate::execution::context::SessionContext #[derive(Clone)] @@ -140,7 +162,6 @@ pub struct SessionState { table_factories: HashMap>, /// Runtime environment runtime_env: Arc, - /// [FunctionFactory] to support pluggable user defined function handler. /// /// It will be invoked on `CREATE FUNCTION` statements. @@ -153,6 +174,7 @@ impl Debug for SessionState { f.debug_struct("SessionState") .field("session_id", &self.session_id) .field("analyzer", &"...") + .field("expr_planners", &"...") .field("optimizer", &"...") .field("physical_optimizers", &"...") .field("query_planner", &"...") @@ -175,193 +197,56 @@ impl Debug for SessionState { impl SessionState { /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let catalog_list = - Arc::new(MemoryCatalogProviderList::new()) as Arc; - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - Self::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] pub fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - let session_id = Uuid::new_v4().to_string(); - - // Create table_factories for all default formats - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - if config.create_default_catalog_and_schema() { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema( - &config, - &table_factories, - &runtime, - &default_catalog, - ); - - catalog_list.register_catalog( - config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); - } - - let expr_planners: Vec> = vec![ - Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), - ]; - - let mut new_self = SessionState { - session_id, - analyzer: Analyzer::new(), - expr_planners, - optimizer: Optimizer::new(), - physical_optimizers: PhysicalOptimizer::new(), - query_planner: Arc::new(DefaultQueryPlanner {}), - catalog_list, - table_functions: HashMap::new(), - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - window_functions: HashMap::new(), - serializer_registry: Arc::new(EmptySerializerRegistry), - file_formats: HashMap::new(), - table_options: TableOptions::default_from_session_config(config.options()), - config, - execution_props: ExecutionProps::new(), - runtime_env: runtime, - table_factories, - function_factory: None, - }; - - #[cfg(feature = "parquet")] - if let Err(e) = - new_self.register_file_format(Arc::new(ParquetFormatFactory::new()), false) - { - log::info!("Unable to register default ParquetFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(JsonFormatFactory::new()), false) - { - log::info!("Unable to register default JsonFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(CsvFormatFactory::new()), false) - { - log::info!("Unable to register default CsvFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(ArrowFormatFactory::new()), false) - { - log::info!("Unable to register default ArrowFormat: {e}") - }; - - if let Err(e) = - new_self.register_file_format(Arc::new(AvroFormatFactory::new()), false) - { - log::info!("Unable to register default AvroFormat: {e}") - }; - - // register built in functions - functions::register_all(&mut new_self) - .expect("can not register built in functions"); - - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(&mut new_self) - .expect("can not register array expressions"); - - functions_aggregate::register_all(&mut new_self) - .expect("can not register aggregate functions"); - - new_self + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() } + /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated( - since = "32.0.0", - note = "Use SessionState::new_with_config_rt_and_catalog_list" - )] + #[deprecated(since = "32.0.0", note = "Use SessionStateBuilder")] pub fn with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, ) -> Self { - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) - } - fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() } pub(crate) fn resolve_table_ref( @@ -400,12 +285,14 @@ impl SessionState { }) } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the random session id. pub fn with_session_id(mut self, session_id: String) -> Self { self.session_id = session_id; self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// override default query planner with `query_planner` pub fn with_query_planner( mut self, @@ -415,6 +302,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Override the [`AnalyzerRule`]s optimizer plan rules. pub fn with_analyzer_rules( mut self, @@ -424,6 +312,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`OptimizerRule`]s used to optimize plans pub fn with_optimizer_rules( mut self, @@ -433,6 +322,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans pub fn with_physical_optimizer_rules( mut self, @@ -452,6 +342,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `optimizer_rule` to the end of the list of /// [`OptimizerRule`]s used to rewrite queries. pub fn add_optimizer_rule( @@ -472,6 +363,7 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Add `physical_optimizer_rule` to the end of the list of /// [`PhysicalOptimizerRule`]s used to rewrite queries. pub fn add_physical_optimizer_rule( @@ -482,6 +374,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Adds a new [`ConfigExtension`] to TableOptions pub fn add_table_options_extension( mut self, @@ -491,6 +384,7 @@ impl SessionState { self } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn with_function_factory( mut self, @@ -505,6 +399,7 @@ impl SessionState { self.function_factory = Some(function_factory); } + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] /// Replace the extension [`SerializerRegistry`] pub fn with_serializer_registry( mut self, @@ -858,19 +753,20 @@ impl SessionState { &self.table_options } - /// Return mutable table opptions + /// Return mutable table options pub fn table_options_mut(&mut self) -> &mut TableOptions { &mut self.table_options } - /// Registers a [`ConfigExtension`] as a table option extention that can be + /// Registers a [`ConfigExtension`] as a table option extension that can be /// referenced from SQL statements executed against this context. pub fn register_table_options_extension(&mut self, extension: T) { self.table_options.extensions.insert(extension) } - /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or CREATE EXTERNAL TABLE statements for reading - /// and writing files of custom formats. + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or + /// CREATE EXTERNAL TABLE statements for reading and writing files of custom + /// formats. pub fn register_file_format( &mut self, file_format: Arc, @@ -950,7 +846,7 @@ impl SessionState { ); } - /// Deregsiter a user defined table function + /// Deregister a user defined table function pub fn deregister_udtf( &mut self, name: &str, @@ -974,6 +870,733 @@ impl SessionState { } } +/// A builder to be used for building [`SessionState`]'s. Defaults will +/// be used for all values unless explicitly provided. +/// +/// See example on [`SessionState`] +pub struct SessionStateBuilder { + session_id: Option, + analyzer: Option, + expr_planners: Option>>, + optimizer: Option, + physical_optimizers: Option, + query_planner: Option>, + catalog_list: Option>, + table_functions: Option>>, + scalar_functions: Option>>, + aggregate_functions: Option>>, + window_functions: Option>>, + serializer_registry: Option>, + file_formats: Option>>, + config: Option, + table_options: Option, + execution_props: Option, + table_factories: Option>>, + runtime_env: Option>, + function_factory: Option>, + // fields to support convenience functions + analyzer_rules: Option>>, + optimizer_rules: Option>>, + physical_optimizer_rules: Option>>, +} + +impl SessionStateBuilder { + /// Returns a new [`SessionStateBuilder`] with no options set. + pub fn new() -> Self { + Self { + session_id: None, + analyzer: None, + expr_planners: None, + optimizer: None, + physical_optimizers: None, + query_planner: None, + catalog_list: None, + table_functions: None, + scalar_functions: None, + aggregate_functions: None, + window_functions: None, + serializer_registry: None, + file_formats: None, + table_options: None, + config: None, + execution_props: None, + table_factories: None, + runtime_env: None, + function_factory: None, + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Returns a new [SessionStateBuilder] based on an existing [SessionState] + /// The session id for the new builder will be unset; all other fields will + /// be cloned from what is set in the provided session state + pub fn new_from_existing(existing: SessionState) -> Self { + Self { + session_id: None, + analyzer: Some(existing.analyzer), + expr_planners: Some(existing.expr_planners), + optimizer: Some(existing.optimizer), + physical_optimizers: Some(existing.physical_optimizers), + query_planner: Some(existing.query_planner), + catalog_list: Some(existing.catalog_list), + table_functions: Some(existing.table_functions), + scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + aggregate_functions: Some( + existing.aggregate_functions.into_values().collect_vec(), + ), + window_functions: Some(existing.window_functions.into_values().collect_vec()), + serializer_registry: Some(existing.serializer_registry), + file_formats: Some(existing.file_formats.into_values().collect_vec()), + config: Some(existing.config), + table_options: Some(existing.table_options), + execution_props: Some(existing.execution_props), + table_factories: Some(existing.table_factories), + runtime_env: Some(existing.runtime_env), + function_factory: existing.function_factory, + + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Set defaults for table_factories, file formats, expr_planners and builtin + /// scalar and aggregate functions. + pub fn with_default_features(mut self) -> Self { + self.table_factories = Some(SessionStateDefaults::default_table_factories()); + self.file_formats = Some(SessionStateDefaults::default_file_formats()); + self.expr_planners = Some(SessionStateDefaults::default_expr_planners()); + self.scalar_functions = Some(SessionStateDefaults::default_scalar_functions()); + self.aggregate_functions = + Some(SessionStateDefaults::default_aggregate_functions()); + self + } + + /// Set the session id. + pub fn with_session_id(mut self, session_id: String) -> Self { + self.session_id = Some(session_id); + self + } + + /// Set the [`AnalyzerRule`]s optimizer plan rules. + pub fn with_analyzer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.analyzer = Some(Analyzer::with_rules(rules)); + self + } + + /// Add `analyzer_rule` to the end of the list of + /// [`AnalyzerRule`]s used to rewrite queries. + pub fn with_analyzer_rule( + mut self, + analyzer_rule: Arc, + ) -> Self { + let mut rules = self.analyzer_rules.unwrap_or_default(); + rules.push(analyzer_rule); + self.analyzer_rules = Some(rules); + self + } + + /// Set the [`OptimizerRule`]s used to optimize plans. + pub fn with_optimizer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.optimizer = Some(Optimizer::with_rules(rules)); + self + } + + /// Add `optimizer_rule` to the end of the list of + /// [`OptimizerRule`]s used to rewrite queries. + pub fn with_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + let mut rules = self.optimizer_rules.unwrap_or_default(); + rules.push(optimizer_rule); + self.optimizer_rules = Some(rules); + self + } + + /// Set the [`ExprPlanner`]s used to customize the behavior of the SQL planner. + pub fn with_expr_planners( + mut self, + expr_planners: Vec>, + ) -> Self { + self.expr_planners = Some(expr_planners); + self + } + + /// Set tje [`PhysicalOptimizerRule`]s used to optimize plans. + pub fn with_physical_optimizer_rules( + mut self, + physical_optimizers: Vec>, + ) -> Self { + self.physical_optimizers = + Some(PhysicalOptimizer::with_rules(physical_optimizers)); + self + } + + /// Add `physical_optimizer_rule` to the end of the list of + /// [`PhysicalOptimizerRule`]s used to rewrite queries. + pub fn with_physical_optimizer_rule( + mut self, + physical_optimizer_rule: Arc, + ) -> Self { + let mut rules = self.physical_optimizer_rules.unwrap_or_default(); + rules.push(physical_optimizer_rule); + self.physical_optimizer_rules = Some(rules); + self + } + + /// Set the [`QueryPlanner`] + pub fn with_query_planner( + mut self, + query_planner: Arc, + ) -> Self { + self.query_planner = Some(query_planner); + self + } + + /// Set the [`CatalogProviderList`] + pub fn with_catalog_list( + mut self, + catalog_list: Arc, + ) -> Self { + self.catalog_list = Some(catalog_list); + self + } + + /// Set the map of [`TableFunction`]s + pub fn with_table_functions( + mut self, + table_functions: HashMap>, + ) -> Self { + self.table_functions = Some(table_functions); + self + } + + /// Set the map of [`ScalarUDF`]s + pub fn with_scalar_functions( + mut self, + scalar_functions: Vec>, + ) -> Self { + self.scalar_functions = Some(scalar_functions); + self + } + + /// Set the map of [`AggregateUDF`]s + pub fn with_aggregate_functions( + mut self, + aggregate_functions: Vec>, + ) -> Self { + self.aggregate_functions = Some(aggregate_functions); + self + } + + /// Set the map of [`WindowUDF`]s + pub fn with_window_functions( + mut self, + window_functions: Vec>, + ) -> Self { + self.window_functions = Some(window_functions); + self + } + + /// Set the [`SerializerRegistry`] + pub fn with_serializer_registry( + mut self, + serializer_registry: Arc, + ) -> Self { + self.serializer_registry = Some(serializer_registry); + self + } + + /// Set the map of [`FileFormatFactory`]s + pub fn with_file_formats( + mut self, + file_formats: Vec>, + ) -> Self { + self.file_formats = Some(file_formats); + self + } + + /// Set the [`SessionConfig`] + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.config = Some(config); + self + } + + /// Set the [`TableOptions`] + pub fn with_table_options(mut self, table_options: TableOptions) -> Self { + self.table_options = Some(table_options); + self + } + + /// Set the [`ExecutionProps`] + pub fn with_execution_props(mut self, execution_props: ExecutionProps) -> Self { + self.execution_props = Some(execution_props); + self + } + + /// Set the map of [`TableProviderFactory`]s + pub fn with_table_factories( + mut self, + table_factories: HashMap>, + ) -> Self { + self.table_factories = Some(table_factories); + self + } + + /// Set the [`RuntimeEnv`] + pub fn with_runtime_env(mut self, runtime_env: Arc) -> Self { + self.runtime_env = Some(runtime_env); + self + } + + /// Set a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + mut self, + function_factory: Option>, + ) -> Self { + self.function_factory = function_factory; + self + } + + /// Builds a [`SessionState`] with the current configuration. + /// + /// Note that there is an explicit option for enabling catalog and schema defaults + /// in [SessionConfig::create_default_catalog_and_schema] which if enabled + /// will be built here. + pub fn build(self) -> SessionState { + let Self { + session_id, + analyzer, + expr_planners, + optimizer, + physical_optimizers, + query_planner, + catalog_list, + table_functions, + scalar_functions, + aggregate_functions, + window_functions, + serializer_registry, + file_formats, + table_options, + config, + execution_props, + table_factories, + runtime_env, + function_factory, + analyzer_rules, + optimizer_rules, + physical_optimizer_rules, + } = self; + + let config = config.unwrap_or_default(); + let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + + let mut state = SessionState { + session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + analyzer: analyzer.unwrap_or_default(), + expr_planners: expr_planners.unwrap_or_default(), + optimizer: optimizer.unwrap_or_default(), + physical_optimizers: physical_optimizers.unwrap_or_default(), + query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list + .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) + as Arc), + table_functions: table_functions.unwrap_or_default(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + serializer_registry: serializer_registry + .unwrap_or(Arc::new(EmptySerializerRegistry)), + file_formats: HashMap::new(), + table_options: table_options + .unwrap_or(TableOptions::default_from_session_config(config.options())), + config, + execution_props: execution_props.unwrap_or_default(), + table_factories: table_factories.unwrap_or_default(), + runtime_env, + function_factory, + }; + + if let Some(file_formats) = file_formats { + for file_format in file_formats { + if let Err(e) = state.register_file_format(file_format, false) { + info!("Unable to register file format: {e}") + }; + } + } + + if let Some(scalar_functions) = scalar_functions { + scalar_functions.into_iter().for_each(|udf| { + let existing_udf = state.register_udf(udf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(aggregate_functions) = aggregate_functions { + aggregate_functions.into_iter().for_each(|udaf| { + let existing_udf = state.register_udaf(udaf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(window_functions) = window_functions { + window_functions.into_iter().for_each(|udwf| { + let existing_udf = state.register_udwf(udwf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if state.config.create_default_catalog_and_schema() { + let default_catalog = SessionStateDefaults::default_catalog( + &state.config, + &state.table_factories, + &state.runtime_env, + ); + + state.catalog_list.register_catalog( + state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ); + } + + if let Some(analyzer_rules) = analyzer_rules { + for analyzer_rule in analyzer_rules { + state.analyzer.rules.push(analyzer_rule); + } + } + + if let Some(optimizer_rules) = optimizer_rules { + for optimizer_rule in optimizer_rules { + state.optimizer.rules.push(optimizer_rule); + } + } + + if let Some(physical_optimizer_rules) = physical_optimizer_rules { + for physical_optimizer_rule in physical_optimizer_rules { + state + .physical_optimizers + .rules + .push(physical_optimizer_rule); + } + } + + state + } + + /// Returns the current session_id value + pub fn session_id(&self) -> &Option { + &self.session_id + } + + /// Returns the current analyzer value + pub fn analyzer(&mut self) -> &mut Option { + &mut self.analyzer + } + + /// Returns the current expr_planners value + pub fn expr_planners(&mut self) -> &mut Option>> { + &mut self.expr_planners + } + + /// Returns the current optimizer value + pub fn optimizer(&mut self) -> &mut Option { + &mut self.optimizer + } + + /// Returns the current physical_optimizers value + pub fn physical_optimizers(&mut self) -> &mut Option { + &mut self.physical_optimizers + } + + /// Returns the current query_planner value + pub fn query_planner(&mut self) -> &mut Option> { + &mut self.query_planner + } + + /// Returns the current catalog_list value + pub fn catalog_list(&mut self) -> &mut Option> { + &mut self.catalog_list + } + + /// Returns the current table_functions value + pub fn table_functions( + &mut self, + ) -> &mut Option>> { + &mut self.table_functions + } + + /// Returns the current scalar_functions value + pub fn scalar_functions(&mut self) -> &mut Option>> { + &mut self.scalar_functions + } + + /// Returns the current aggregate_functions value + pub fn aggregate_functions(&mut self) -> &mut Option>> { + &mut self.aggregate_functions + } + + /// Returns the current window_functions value + pub fn window_functions(&mut self) -> &mut Option>> { + &mut self.window_functions + } + + /// Returns the current serializer_registry value + pub fn serializer_registry(&mut self) -> &mut Option> { + &mut self.serializer_registry + } + + /// Returns the current file_formats value + pub fn file_formats(&mut self) -> &mut Option>> { + &mut self.file_formats + } + + /// Returns the current session_config value + pub fn config(&mut self) -> &mut Option { + &mut self.config + } + + /// Returns the current table_options value + pub fn table_options(&mut self) -> &mut Option { + &mut self.table_options + } + + /// Returns the current execution_props value + pub fn execution_props(&mut self) -> &mut Option { + &mut self.execution_props + } + + /// Returns the current table_factories value + pub fn table_factories( + &mut self, + ) -> &mut Option>> { + &mut self.table_factories + } + + /// Returns the current runtime_env value + pub fn runtime_env(&mut self) -> &mut Option> { + &mut self.runtime_env + } + + /// Returns the current function_factory value + pub fn function_factory(&mut self) -> &mut Option> { + &mut self.function_factory + } + + /// Returns the current analyzer_rules value + pub fn analyzer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.analyzer_rules + } + + /// Returns the current optimizer_rules value + pub fn optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.optimizer_rules + } + + /// Returns the current physical_optimizer_rules value + pub fn physical_optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.physical_optimizer_rules + } +} + +impl Default for SessionStateBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for SessionStateBuilder { + fn from(state: SessionState) -> Self { + SessionStateBuilder::new_from_existing(state) + } +} + +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} + +impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + /// returns the default MemoryCatalogProvider + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + /// returns the list of default [`ExprPlanner`]s + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + /// returns the list of default [`ScalarUDF']'s + pub fn default_scalar_functions() -> Vec> { + let mut functions: Vec> = functions::all_default_functions(); + #[cfg(feature = "array_expressions")] + functions.append(&mut functions_array::all_default_array_functions()); + + functions + } + + /// returns the list of default [`AggregateUDF']'s + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns the list of default [`FileFormatFactory']'s + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d2bc334ec3248..efc83d8f6b5c2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2269,6 +2269,7 @@ mod tests { use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; + use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; @@ -2282,7 +2283,11 @@ mod tests { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } async fn plan(logical_plan: &LogicalPlan) -> Result> { diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index bea6f7b9ceb7b..6c0a2fc7bec47 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -16,9 +16,8 @@ // under the License. //! Object store implementation used for testing use crate::execution::context::SessionState; +use crate::execution::session_state::SessionStateBuilder; use crate::prelude::SessionContext; -use datafusion_execution::config::SessionConfig; -use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -44,10 +43,7 @@ pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, Sessi ( Arc::new(memory), - SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + SessionStateBuilder::new().with_default_features().build(), ) } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index f1d57c44293be..1b2a6770cf013 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,7 +42,8 @@ use url::Url; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, SessionState}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::{parquet_test_data, populate_csv_partitions}; @@ -1544,7 +1545,11 @@ async fn unnest_non_nullable_list() -> Result<()> { async fn test_read_batches() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![ @@ -1594,7 +1599,11 @@ async fn test_read_batches() -> Result<()> { async fn test_read_batches_empty() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let batches = vec![]; @@ -1608,9 +1617,7 @@ async fn test_read_batches_empty() -> Result<()> { #[tokio::test] async fn consecutive_projection_same_schema() -> Result<()> { - let config = SessionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new().with_default_features().build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7ef24609e238d..1d151f9fd3683 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -38,6 +38,7 @@ use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; @@ -459,13 +460,16 @@ impl TestCase { let runtime = RuntimeEnv::new(rt_config).unwrap(); // Configure execution - let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); - let state = match scenario.rules() { - Some(rules) => state.with_physical_optimizer_rules(rules), - None => state, + let builder = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(Arc::new(runtime)) + .with_default_features(); + let builder = match scenario.rules() { + Some(rules) => builder.with_physical_optimizer_rules(rules), + None => builder, }; - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new_with_state(builder.build()); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 9f94a59a3e598..bf25b36f48e8b 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -35,6 +35,7 @@ use datafusion_execution::cache::cache_unit::{ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use tempfile::tempdir; #[tokio::test] @@ -167,10 +168,7 @@ async fn get_listing_table( ) -> ListingTable { let schema = opt .infer_schema( - &SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + &SessionStateBuilder::new().with_default_features().build(), table_path, ) .await diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 2174009b85573..83712053b9542 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::test_util::TestTableFactory; use super::*; #[tokio::test] async fn create_custom_table() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); @@ -45,10 +41,7 @@ async fn create_custom_table() -> Result<()> { #[tokio::test] async fn create_external_table_with_ddl() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 38ed142cf922f..a44f522ba95ac 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,6 +92,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; @@ -290,10 +291,14 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let mut state = SessionState::new_with_config_rt(config, runtime) + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); SessionContext::new_with_state(state) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f764a050a6cdd..d0209d811b7ce 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -39,8 +39,7 @@ use prost::Message; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ @@ -202,10 +201,7 @@ async fn roundtrip_custom_tables() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // replace factories *state.table_factories_mut() = table_factories; let ctx = SessionContext::new_with_state(state); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 2893b1a31a26c..5b2d0fbacaef0 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -28,7 +28,6 @@ use std::sync::Arc; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ @@ -37,6 +36,7 @@ use datafusion::logical_expr::{ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -1121,11 +1121,12 @@ async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { } async fn create_context() -> Result { - let mut state = SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ) - .with_serializer_registry(Arc::new(MockSerializerRegistry)); + let mut state = SessionStateBuilder::new() + .with_config(SessionConfig::default()) + .with_runtime_env(Arc::new(RuntimeEnv::default())) + .with_default_features() + .with_serializer_registry(Arc::new(MockSerializerRegistry)) + .build(); // register udaf for test, e.g. `sum()` datafusion_functions_aggregate::register_all(&mut state)